fix (data): Parallel connection pool.

This commit is contained in:
2023-08-23 17:31:38 +03:00
parent 5bca9bca33
commit df9b835cd5
27 changed files with 1175 additions and 1021 deletions

View File

@@ -52,7 +52,7 @@ class EventHandler(Generic[T]):
f"Queue on event handler {self.route_name} is full! Discarding event {data}" f"Queue on event handler {self.route_name} is full! Discarding event {data}"
) )
@log_wrap(action='consumer', isolate=False) @log_wrap(action='consumer')
async def consumer(self): async def consumer(self):
while True: while True:
try: try:
@@ -76,7 +76,7 @@ class EventHandler(Generic[T]):
) )
pass pass
@log_wrap(action='batch', isolate=False) @log_wrap(action='batch')
async def process_batch(self): async def process_batch(self):
logger.debug("Processing Batch") logger.debug("Processing Batch")
# TODO: copy syntax might be more efficient here # TODO: copy syntax might be more efficient here

View File

@@ -123,7 +123,7 @@ class AnalyticsServer:
log_action_stack.set(['Analytics']) log_action_stack.set(['Analytics'])
log_app.set(conf.analytics['appname']) log_app.set(conf.analytics['appname'])
async with await self.db.connect(): async with self.db.open():
await self.talk.connect() await self.talk.connect()
await self.attach_event_handlers() await self.attach_event_handlers()
self._snap_task = asyncio.create_task(self.snapshot_loop()) self._snap_task = asyncio.create_task(self.snapshot_loop())

View File

@@ -38,7 +38,7 @@ async def main():
intents.message_content = True intents.message_content = True
intents.presences = False intents.presences = False
async with await db.connect(): async with db.open():
version = await db.version() version = await db.version()
if version.version != DATA_VERSION: if version.version != DATA_VERSION:
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate." error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."

View File

@@ -57,16 +57,14 @@ class CoreCog(LionCog):
async def cog_load(self): async def cog_load(self):
# Fetch (and possibly create) core data rows. # Fetch (and possibly create) core data rows.
conn = await self.bot.db.get_connection() self.app_config = await self.data.AppConfig.fetch_or_create(appname)
async with conn.transaction(): self.bot_config = await self.data.BotConfig.fetch_or_create(appname)
self.app_config = await self.data.AppConfig.fetch_or_create(appname) self.shard_data = await self.data.Shard.fetch_or_create(
self.bot_config = await self.data.BotConfig.fetch_or_create(appname) shardname,
self.shard_data = await self.data.Shard.fetch_or_create( appname=appname,
shardname, shard_id=self.bot.shard_id,
appname=appname, shard_count=self.bot.shard_count
shard_id=self.bot.shard_id, )
shard_count=self.bot.shard_count
)
self.bot.add_listener(self.shard_update_guilds, name='on_guild_join') self.bot.add_listener(self.shard_update_guilds, name='on_guild_join')
self.bot.add_listener(self.shard_update_guilds, name='on_guild_remove') self.bot.add_listener(self.shard_update_guilds, name='on_guild_remove')

View File

@@ -5,6 +5,7 @@ from cachetools import TTLCache
import discord import discord
from meta import conf from meta import conf
from meta.logger import log_wrap
from data import Table, Registry, Column, RowModel, RegisterEnum from data import Table, Registry, Column, RowModel, RegisterEnum
from data.models import WeakCache from data.models import WeakCache
from data.columns import Integer, String, Bool, Timestamp from data.columns import Integer, String, Bool, Timestamp
@@ -287,6 +288,7 @@ class CoreData(Registry, name="core"):
_timestamp = Timestamp() _timestamp = Timestamp()
@classmethod @classmethod
@log_wrap(action="Add Pending Coins")
async def add_pending(cls, pending: list[tuple[int, int, int]]) -> list['CoreData.Member']: async def add_pending(cls, pending: list[tuple[int, int, int]]) -> list['CoreData.Member']:
""" """
Safely add pending coins to a list of members. Safely add pending coins to a list of members.
@@ -316,39 +318,40 @@ class CoreData(Registry, name="core"):
) )
) )
# TODO: Replace with copy syntax/query? # TODO: Replace with copy syntax/query?
conn = await cls.table.connector.get_connection() async with cls.table.connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
tuple(chain(*pending)) tuple(chain(*pending))
) )
rows = await cursor.fetchall() rows = await cursor.fetchall()
return cls._make_rows(*rows) return cls._make_rows(*rows)
@classmethod @classmethod
@log_wrap(action='get_member_rank')
async def get_member_rank(cls, guildid, userid, untracked): async def get_member_rank(cls, guildid, userid, untracked):
""" """
Get the time and coin ranking for the given member, ignoring the provided untracked members. Get the time and coin ranking for the given member, ignoring the provided untracked members.
""" """
conn = await cls.table.connector.get_connection() async with cls.table.connector.connection() as conn:
async with conn.cursor() as curs: async with conn.cursor() as curs:
await curs.execute( await curs.execute(
""" """
SELECT SELECT
time_rank, coin_rank time_rank, coin_rank
FROM ( FROM (
SELECT SELECT
userid, userid,
row_number() OVER (ORDER BY total_tracked_time DESC, userid ASC) AS time_rank, row_number() OVER (ORDER BY total_tracked_time DESC, userid ASC) AS time_rank,
row_number() OVER (ORDER BY total_coins DESC, userid ASC) AS coin_rank row_number() OVER (ORDER BY total_coins DESC, userid ASC) AS coin_rank
FROM members_totals FROM members_totals
WHERE WHERE
guildid=%s AND userid NOT IN %s guildid=%s AND userid NOT IN %s
) AS guild_ranks WHERE userid=%s ) AS guild_ranks WHERE userid=%s
""", """,
(guildid, tuple(untracked), userid) (guildid, tuple(untracked), userid)
) )
return (await curs.fetchone()) or (None, None) return (await curs.fetchone()) or (None, None)
class LionHook(RowModel): class LionHook(RowModel):
""" """

View File

@@ -1,6 +1,7 @@
# from enum import Enum # from enum import Enum
from typing import Optional from typing import Optional
from psycopg.types.enum import register_enum, EnumInfo from psycopg.types.enum import register_enum, EnumInfo
from psycopg import AsyncConnection
from .registry import Attachable, Registry from .registry import Attachable, Registry
@@ -23,10 +24,17 @@ class RegisterEnum(Attachable):
connector = registry._conn connector = registry._conn
if connector is None: if connector is None:
raise ValueError("Cannot initialise without connector!") raise ValueError("Cannot initialise without connector!")
connection = await connector.get_connection() connector.connect_hook(self.connection_hook)
if connection is None: # await connector.refresh_pool()
raise ValueError("Cannot Init without connection.") # The below may be somewhat dangerous
info = await EnumInfo.fetch(connection, self.name) # But adaption should never write to the database
await connector.map_over_pool(self.connection_hook)
# if conn := connector.conn:
# # Ensure the adaption is run in the current context as well
# await self.connection_hook(conn)
async def connection_hook(self, conn: AsyncConnection):
info = await EnumInfo.fetch(conn, self.name)
if info is None: if info is None:
raise ValueError(f"Enum {self.name} not found in database.") raise ValueError(f"Enum {self.name} not found in database.")
register_enum(info, connection, self.enum, mapping=list(self.mapping.items())) register_enum(info, conn, self.enum, mapping=list(self.mapping.items()))

View File

@@ -1,7 +1,10 @@
from typing import Protocol, runtime_checkable, Callable, Awaitable from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
import logging import logging
from contextvars import ContextVar
from contextlib import asynccontextmanager
import psycopg as psq import psycopg as psq
from psycopg_pool import AsyncConnectionPool
from psycopg.pq import TransactionStatus from psycopg.pq import TransactionStatus
from .cursor import AsyncLoggingCursor from .cursor import AsyncLoggingCursor
@@ -10,42 +13,110 @@ logger = logging.getLogger(__name__)
row_factory = psq.rows.dict_row row_factory = psq.rows.dict_row
ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None)
class Connector: class Connector:
cursor_factory = AsyncLoggingCursor cursor_factory = AsyncLoggingCursor
def __init__(self, conn_args): def __init__(self, conn_args):
self._conn_args = conn_args self._conn_args = conn_args
self.conn: psq.AsyncConnection = None self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
self.pool = self.make_pool()
self.conn_hooks = [] self.conn_hooks = []
async def get_connection(self) -> psq.AsyncConnection: @property
def conn(self) -> Optional[psq.AsyncConnection]:
""" """
Get the current active connection. Convenience property for the current context connection.
This should never be cached outside of a transaction.
""" """
# TODO: Reconnection logic? return ctx_connection.get()
if not self.conn:
raise ValueError("Attempting to get connection before initialisation!")
if self.conn.info.transaction_status is TransactionStatus.INERROR:
await self.connect()
logger.error(
"Database connection transaction failed!! This should not happen. Reconnecting."
)
return self.conn
async def connect(self) -> psq.AsyncConnection: @conn.setter
logger.info("Establishing connection to database.", extra={'action': "Data Connect"}) def conn(self, conn: psq.AsyncConnection):
self.conn = await psq.AsyncConnection.connect( """
self._conn_args, autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory Set the contextual connection in the current context.
Always do this in an isolated context!
"""
ctx_connection.set(conn)
def make_pool(self) -> AsyncConnectionPool:
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
return AsyncConnectionPool(
self._conn_args,
open=False,
min_size=4,
max_size=8,
configure=self._setup_connection,
kwargs=self._conn_kwargs
) )
for hook in self.conn_hooks:
await hook(self.conn)
return self.conn
async def reconnect(self) -> psq.AsyncConnection: async def refresh_pool(self):
return await self.connect() """
Refresh the pool.
The point of this is to invalidate any existing connections so that the connection set up is run again.
Better ways should be sought (a way to
"""
logger.info("Pool refresh requested, closing and reopening.")
old_pool = self.pool
self.pool = self.make_pool()
await self.pool.open()
logger.info(f"Old pool statistics: {self.pool.get_stats()}")
await old_pool.close()
logger.info("Pool refresh complete.")
async def map_over_pool(self, callable):
"""
Dangerous method to call a method on each connection in the pool.
Utilises private methods of the AsyncConnectionPool.
"""
async with self.pool._lock:
conns = list(self.pool._pool)
while conns:
conn = conns.pop()
try:
await callable(conn)
except Exception:
logger.exception(f"Mapped connection task failed. {callable.__name__}")
@asynccontextmanager
async def open(self):
try:
logger.info("Opening database pool.")
await self.pool.open()
yield
finally:
# May be a different pool!
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
await self.pool.close()
@asynccontextmanager
async def connection(self) -> psq.AsyncConnection:
"""
Asynchronous context manager to get and manage a connection.
If the context connection is set, uses this and does not manage the lifetime.
Otherwise, requests a new connection from the pool and returns it when done.
"""
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
if (conn := self.conn):
yield conn
else:
async with self.pool.connection() as conn:
yield conn
async def _setup_connection(self, conn: psq.AsyncConnection):
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
for hook in self.conn_hooks:
try:
await hook(conn)
except Exception:
logger.exception("Exception encountered setting up new connection")
return conn
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]): def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
""" """

View File

@@ -35,12 +35,13 @@ class Database(Connector):
""" """
Return the current schema version as a Version namedtuple. Return the current schema version as a Version namedtuple.
""" """
async with self.conn.cursor() as cursor: async with self.connection() as conn:
# Get last entry in version table, compare against desired version async with conn.cursor() as cursor:
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1") # Get last entry in version table, compare against desired version
row = await cursor.fetchone() await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
if row: row = await cursor.fetchone()
return Version(row['version'], row['time'], row['author']) if row:
else: return Version(row['version'], row['time'], row['author'])
# No versions in the database else:
return Version(-1, None, None) # No versions in the database
return Version(-1, None, None)

View File

@@ -101,12 +101,12 @@ class Query(Generic[QueryResult]):
if self.connector is None: if self.connector is None:
raise ValueError("Cannot execute query without cursor, connection, or connector.") raise ValueError("Cannot execute query without cursor, connection, or connector.")
else: else:
conn = await self.connector.get_connection() async with self.connector.connection() as conn:
async with conn.cursor() as cursor:
data = await self._execute(cursor)
else: else:
conn = self.conn async with self.conn.cursor() as cursor:
data = await self._execute(cursor)
async with conn.cursor() as cursor:
data = await self._execute(cursor)
else: else:
data = await self._execute(cursor) data = await self._execute(cursor)
return data return data

View File

@@ -1,4 +1,5 @@
from typing import Optional, Union from typing import Optional, Union
import asyncio
import discord import discord
from discord.ext import commands as cmds from discord.ext import commands as cmds
@@ -182,30 +183,34 @@ class Economy(LionCog):
# We may need to do a mass row create operation. # We may need to do a mass row create operation.
targetids = set(target.id for target in targets) targetids = set(target.id for target in targets)
if len(targets) > 1: if len(targets) > 1:
conn = await ctx.bot.db.get_connection() async def wrapper():
async with conn.transaction(): async with self.bot.db.connection() as conn:
# First fetch the members which currently exist self.bot.db.conn = conn
query = self.bot.core.data.Member.table.select_where(guildid=ctx.guild.id) async with conn.transaction():
query.select('userid').with_no_adapter() # First fetch the members which currently exist
if 2 * len(targets) < len(ctx.guild.members): query = self.bot.core.data.Member.table.select_where(guildid=ctx.guild.id)
# More efficient to fetch the targets explicitly query.select('userid').with_no_adapter()
query.where(userid=list(targetids)) if 2 * len(targets) < len(ctx.guild.members):
existent_rows = await query # More efficient to fetch the targets explicitly
existentids = set(r['userid'] for r in existent_rows) query.where(userid=list(targetids))
existent_rows = await query
existentids = set(r['userid'] for r in existent_rows)
# Then check if any new userids need adding, and if so create them # Then check if any new userids need adding, and if so create them
new_ids = targetids.difference(existentids) new_ids = targetids.difference(existentids)
if new_ids: if new_ids:
# We use ON CONFLICT IGNORE here in case the users already exist. # We use ON CONFLICT IGNORE here in case the users already exist.
await self.bot.core.data.User.table.insert_many( await self.bot.core.data.User.table.insert_many(
('userid',), ('userid',),
*((id,) for id in new_ids) *((id,) for id in new_ids)
).on_conflict(ignore=True) ).on_conflict(ignore=True)
# TODO: Replace 0 here with the starting_coin value # TODO: Replace 0 here with the starting_coin value
await self.bot.core.data.Member.table.insert_many( await self.bot.core.data.Member.table.insert_many(
('guildid', 'userid', 'coins'), ('guildid', 'userid', 'coins'),
*((ctx.guild.id, id, 0) for id in new_ids) *((ctx.guild.id, id, 0) for id in new_ids)
).on_conflict(ignore=True) ).on_conflict(ignore=True)
task = asyncio.create_task(wrapper(), name="wrapped-create-members")
await task
else: else:
# With only one target, we can take a simpler path, and make better use of local caches. # With only one target, we can take a simpler path, and make better use of local caches.
await self.bot.core.lions.fetch_member(ctx.guild.id, target.id) await self.bot.core.lions.fetch_member(ctx.guild.id, target.id)
@@ -703,31 +708,34 @@ class Economy(LionCog):
# Alternative flow could be waiting until the target user presses accept # Alternative flow could be waiting until the target user presses accept
await ctx.interaction.response.defer(thinking=True, ephemeral=True) await ctx.interaction.response.defer(thinking=True, ephemeral=True)
conn = await self.bot.db.get_connection() async def wrapped():
async with conn.transaction(): async with self.bot.db.connection() as conn:
# We do this in a transaction so that if something goes wrong, self.bot.db.conn = conn
# the coins deduction is rolled back atomicly async with conn.transaction():
balance = ctx.alion.data.coins # We do this in a transaction so that if something goes wrong,
if amount > balance: # the coins deduction is rolled back atomicly
await ctx.interaction.edit_original_response( balance = ctx.alion.data.coins
embed=error_embed( if amount > balance:
t(_p( await ctx.interaction.edit_original_response(
'cmd:send|error:insufficient', embed=error_embed(
"You do not have enough lioncoins to do this!\n" t(_p(
"`Current Balance:` {coin_emoji}{balance}" 'cmd:send|error:insufficient',
)).format( "You do not have enough lioncoins to do this!\n"
coin_emoji=self.bot.config.emojis.getemoji('coin'), "`Current Balance:` {coin_emoji}{balance}"
balance=balance )).format(
coin_emoji=self.bot.config.emojis.getemoji('coin'),
balance=balance
)
),
) )
), return
)
return
# Transfer the coins # Transfer the coins
await ctx.alion.data.update(coins=(Member.coins - amount)) await ctx.alion.data.update(coins=(Member.coins - amount))
await target_lion.data.update(coins=(Member.coins + amount)) await target_lion.data.update(coins=(Member.coins + amount))
# TODO: Audit trail # TODO: Audit trail
await asyncio.create_task(wrapped(), name="wrapped-send")
# Message target # Message target
embed = discord.Embed( embed = discord.Embed(

View File

@@ -1,6 +1,7 @@
from enum import Enum from enum import Enum
from psycopg import sql from psycopg import sql
from meta.logger import log_wrap
from data import Registry, RowModel, RegisterEnum, JOINTYPE, RawExpr from data import Registry, RowModel, RegisterEnum, JOINTYPE, RawExpr
from data.columns import Integer, Bool, Column, Timestamp from data.columns import Integer, Bool, Column, Timestamp
from core.data import CoreData from core.data import CoreData
@@ -101,6 +102,7 @@ class EconomyData(Registry, name='economy'):
created_at = Timestamp() created_at = Timestamp()
@classmethod @classmethod
@log_wrap(action='execute_transaction')
async def execute_transaction( async def execute_transaction(
cls, cls,
transaction_type: TransactionType, transaction_type: TransactionType,
@@ -108,25 +110,27 @@ class EconomyData(Registry, name='economy'):
from_account: int, to_account: int, amount: int, bonus: int = 0, from_account: int, to_account: int, amount: int, bonus: int = 0,
refunds: int = None refunds: int = None
): ):
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.transaction(): cls._connector.conn = conn
transaction = await cls.create( async with conn.transaction():
transactiontype=transaction_type, transaction = await cls.create(
guildid=guildid, actorid=actorid, amount=amount, bonus=bonus, transactiontype=transaction_type,
from_account=from_account, to_account=to_account, guildid=guildid, actorid=actorid, amount=amount, bonus=bonus,
refunds=refunds from_account=from_account, to_account=to_account,
) refunds=refunds
if from_account is not None: )
await CoreData.Member.table.update_where( if from_account is not None:
guildid=guildid, userid=from_account await CoreData.Member.table.update_where(
).set(coins=SAFECOINS(CoreData.Member.coins - (amount + bonus))) guildid=guildid, userid=from_account
if to_account is not None: ).set(coins=SAFECOINS(CoreData.Member.coins - (amount + bonus)))
await CoreData.Member.table.update_where( if to_account is not None:
guildid=guildid, userid=to_account await CoreData.Member.table.update_where(
).set(coins=SAFECOINS(CoreData.Member.coins + (amount + bonus))) guildid=guildid, userid=to_account
return transaction ).set(coins=SAFECOINS(CoreData.Member.coins + (amount + bonus)))
return transaction
@classmethod @classmethod
@log_wrap(action='execute_transactions')
async def execute_transactions(cls, *transactions): async def execute_transactions(cls, *transactions):
""" """
Execute multiple transactions in one data transaction. Execute multiple transactions in one data transaction.
@@ -142,65 +146,68 @@ class EconomyData(Registry, name='economy'):
if not transactions: if not transactions:
return [] return []
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.transaction(): cls._connector.conn = conn
# Create the transactions async with conn.transaction():
rows = await cls.table.insert_many( # Create the transactions
( rows = await cls.table.insert_many(
'transactiontype', (
'guildid', 'actorid', 'transactiontype',
'from_account', 'to_account', 'guildid', 'actorid',
'amount', 'bonus', 'from_account', 'to_account',
'refunds' 'amount', 'bonus',
), 'refunds'
*transactions ),
).with_adapter(cls._make_rows) *transactions
).with_adapter(cls._make_rows)
# Update the members # Update the members
transtable = TemporaryTable( transtable = TemporaryTable(
'_guildid', '_userid', '_amount', '_guildid', '_userid', '_amount',
types=('BIGINT', 'BIGINT', 'INTEGER') types=('BIGINT', 'BIGINT', 'INTEGER')
) )
values = transtable.values values = transtable.values
for transaction in transactions: for transaction in transactions:
_, guildid, _, from_acc, to_acc, amount, bonus, _ = transaction _, guildid, _, from_acc, to_acc, amount, bonus, _ = transaction
coins = amount + bonus coins = amount + bonus
if coins: if coins:
if from_acc: if from_acc:
values.append((guildid, from_acc, -1 * coins)) values.append((guildid, from_acc, -1 * coins))
if to_acc: if to_acc:
values.append((guildid, to_acc, coins)) values.append((guildid, to_acc, coins))
if values: if values:
Member = CoreData.Member Member = CoreData.Member
await Member.table.update_where( await Member.table.update_where(
guildid=transtable['_guildid'], userid=transtable['_userid'] guildid=transtable['_guildid'], userid=transtable['_userid']
).set( ).set(
coins=SAFECOINS(Member.coins + transtable['_amount']) coins=SAFECOINS(Member.coins + transtable['_amount'])
).from_expr(transtable) ).from_expr(transtable)
return rows return rows
@classmethod @classmethod
@log_wrap(action='refund_transactions')
async def refund_transactions(cls, *transactionids, actorid=0): async def refund_transactions(cls, *transactionids, actorid=0):
if not transactionids: if not transactionids:
return [] return []
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.transaction(): cls._connector.conn = conn
# First fetch the transaction rows to refund async with conn.transaction():
data = await cls.table.select_where(transactionid=transactionids) # First fetch the transaction rows to refund
if data: data = await cls.table.select_where(transactionid=transactionids)
# Build the transaction refund data if data:
records = [ # Build the transaction refund data
( records = [
TransactionType.REFUND, (
tr['guildid'], actorid, TransactionType.REFUND,
tr['to_account'], tr['from_account'], tr['guildid'], actorid,
tr['amount'] + tr['bonus'], 0, tr['to_account'], tr['from_account'],
tr['transactionid'] tr['amount'] + tr['bonus'], 0,
) tr['transactionid']
for tr in data )
] for tr in data
# Execute refund transactions ]
return await cls.execute_transactions(*records) # Execute refund transactions
return await cls.execute_transactions(*records)
class ShopTransaction(RowModel): class ShopTransaction(RowModel):
""" """
@@ -217,19 +224,21 @@ class EconomyData(Registry, name='economy'):
itemid = Integer() itemid = Integer()
@classmethod @classmethod
@log_wrap(action='purchase_transaction')
async def purchase_transaction( async def purchase_transaction(
cls, cls,
guildid: int, actorid: int, guildid: int, actorid: int,
userid: int, itemid: int, amount: int userid: int, itemid: int, amount: int
): ):
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.transaction(): cls._connector.conn = conn
row = await EconomyData.Transaction.execute_transaction( async with conn.transaction():
TransactionType.SHOP_PURCHASE, row = await EconomyData.Transaction.execute_transaction(
guildid=guildid, actorid=actorid, from_account=userid, to_account=None, TransactionType.SHOP_PURCHASE,
amount=amount guildid=guildid, actorid=actorid, from_account=userid, to_account=None,
) amount=amount
return await cls.create(transactionid=row.transactionid, itemid=itemid) )
return await cls.create(transactionid=row.transactionid, itemid=itemid)
class TaskTransaction(RowModel): class TaskTransaction(RowModel):
""" """
@@ -263,19 +272,21 @@ class EconomyData(Registry, name='economy'):
return result[0]['recent'] or 0 return result[0]['recent'] or 0
@classmethod @classmethod
@log_wrap(action='reward_completed_tasks')
async def reward_completed(cls, userid, guildid, count, amount): async def reward_completed(cls, userid, guildid, count, amount):
""" """
Reward the specified member `amount` coins for completing `count` tasks. Reward the specified member `amount` coins for completing `count` tasks.
""" """
# TODO: Bonus logic, perhaps apply_bonus(amount), or put this method in the economy cog? # TODO: Bonus logic, perhaps apply_bonus(amount), or put this method in the economy cog?
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.transaction(): cls._connector.conn = conn
row = await EconomyData.Transaction.execute_transaction( async with conn.transaction():
TransactionType.TASKS, row = await EconomyData.Transaction.execute_transaction(
guildid=guildid, actorid=userid, from_account=None, to_account=userid, TransactionType.TASKS,
amount=amount guildid=guildid, actorid=userid, from_account=None, to_account=userid,
) amount=amount
return await cls.create(transactionid=row.transactionid, count=count) )
return await cls.create(transactionid=row.transactionid, count=count)
class SessionTransaction(RowModel): class SessionTransaction(RowModel):
""" """

View File

@@ -193,18 +193,19 @@ class MemberAdminCog(LionCog):
await lion.data.update(last_left=utc_now()) await lion.data.update(last_left=utc_now())
# Save member roles # Save member roles
conn = await self.bot.db.get_connection() async with self.bot.db.connection() as conn:
async with conn.transaction(): self.bot.db.conn = conn
await self.data.past_roles.delete_where( async with conn.transaction():
guildid=member.guild.id, await self.data.past_roles.delete_where(
userid=member.id guildid=member.guild.id,
) userid=member.id
# Insert current member roles
if member.roles:
await self.data.past_roles.insert_many(
('guildid', 'userid', 'roleid'),
*((member.guild.id, member.id, role.id) for role in member.roles)
) )
# Insert current member roles
if member.roles:
await self.data.past_roles.insert_many(
('guildid', 'userid', 'roleid'),
*((member.guild.id, member.id, role.id) for role in member.roles)
)
logger.debug( logger.debug(
f"Stored persisting roles for member <uid:{member.id}> in <gid:{member.guild.id}>." f"Stored persisting roles for member <uid:{member.id}> in <gid:{member.guild.id}>."
) )

View File

@@ -190,6 +190,8 @@ class ModerationCog(LionCog):
update_args[instance._column] = instance.data update_args[instance._column] = instance.data
ack_lines.append(instance.update_message) ack_lines.append(instance.update_message)
await ctx.lguild.data.update(**update_args)
# Do the ack # Do the ack
tick = self.bot.config.emojis.tick tick = self.bot.config.emojis.tick
embed = discord.Embed( embed = discord.Embed(

View File

@@ -483,65 +483,63 @@ class RoleMenu:
)).format(role=role.name) )).format(role=role.name)
) )
conn = await self.bot.db.get_connection() # Remove the role
async with conn.transaction(): try:
# Remove the role await member.remove_roles(role)
try: except discord.Forbidden:
await member.remove_roles(role) raise UserInputError(
except discord.Forbidden: t(_p(
raise UserInputError( 'rolemenu|deselect|error:perms',
t(_p( "I don't have enough permissions to remove this role from you!"
'rolemenu|deselect|error:perms',
"I don't have enough permissions to remove this role from you!"
))
)
except discord.HTTPException:
raise UserInputError(
t(_p(
'rolemenu|deselect|error:discord',
"An unknown error occurred removing your role! Please try again later."
))
)
# Update history
now = utc_now()
history = await self.cog.data.RoleMenuHistory.table.update_where(
menuid=self.data.menuid,
roleid=role.id,
userid=member.id,
removed_at=None,
).set(removed_at=now)
await self.cog.cancel_expiring_tasks(*(row['equipid'] for row in history))
# Refund if required
transactionids = [row['transactionid'] for row in history]
if self.config.refunds.value and any(transactionids):
transactionids = [tid for tid in transactionids if tid]
economy: Economy = self.bot.get_cog('Economy')
refunded = await economy.data.Transaction.refund_transactions(*transactionids)
total_refund = sum(row.amount + row.bonus for row in refunded)
else:
total_refund = 0
# Ack the removal
embed = discord.Embed(
colour=discord.Colour.brand_green(),
title=t(_p(
'rolemenu|deslect|success|title',
"Role removed"
)) ))
) )
if total_refund: except discord.HTTPException:
embed.description = t(_p( raise UserInputError(
'rolemenu|deselect|success:refund|desc', t(_p(
"You have removed **{role}**, and been refunded {coin} **{amount}**." 'rolemenu|deselect|error:discord',
)).format(role=role.name, coin=self.bot.config.emojis.coin, amount=total_refund) "An unknown error occurred removing your role! Please try again later."
else: ))
embed.description = t(_p( )
'rolemenu|deselect|success:norefund|desc',
"You have unequipped **{role}**." # Update history
)).format(role=role.name) now = utc_now()
return embed history = await self.cog.data.RoleMenuHistory.table.update_where(
menuid=self.data.menuid,
roleid=role.id,
userid=member.id,
removed_at=None,
).set(removed_at=now)
await self.cog.cancel_expiring_tasks(*(row['equipid'] for row in history))
# Refund if required
transactionids = [row['transactionid'] for row in history]
if self.config.refunds.value and any(transactionids):
transactionids = [tid for tid in transactionids if tid]
economy: Economy = self.bot.get_cog('Economy')
refunded = await economy.data.Transaction.refund_transactions(*transactionids)
total_refund = sum(row.amount + row.bonus for row in refunded)
else:
total_refund = 0
# Ack the removal
embed = discord.Embed(
colour=discord.Colour.brand_green(),
title=t(_p(
'rolemenu|deslect|success|title',
"Role removed"
))
)
if total_refund:
embed.description = t(_p(
'rolemenu|deselect|success:refund|desc',
"You have removed **{role}**, and been refunded {coin} **{amount}**."
)).format(role=role.name, coin=self.bot.config.emojis.coin, amount=total_refund)
else:
embed.description = t(_p(
'rolemenu|deselect|success:norefund|desc',
"You have unequipped **{role}**."
)).format(role=role.name)
return embed
else: else:
# Member does not have the role, selection case. # Member does not have the role, selection case.
required = self.config.required_role.data required = self.config.required_role.data
@@ -591,57 +589,55 @@ class RoleMenu:
) )
) )
conn = await self.bot.db.get_connection() try:
async with conn.transaction(): await member.add_roles(role)
try: except discord.Forbidden:
await member.add_roles(role) raise UserInputError(
except discord.Forbidden: t(_p(
raise UserInputError( 'rolemenu|select|error:perms',
t(_p( "I don't have enough permissions to give you this role!"
'rolemenu|select|error:perms', ))
"I don't have enough permissions to give you this role!"
))
)
except discord.HTTPException:
raise UserInputError(
t(_p(
'rolemenu|select|error:discord',
"An unknown error occurred while assigning your role! "
"Please try again later."
))
)
now = utc_now()
# Create transaction if applicable
if price:
economy: Economy = self.bot.get_cog('Economy')
tx = await economy.data.Transaction.execute_transaction(
transaction_type=TransactionType.OTHER,
guildid=guild.id, actorid=member.id,
from_account=member.id, to_account=None,
amount=price
)
tid = tx.transactionid
else:
tid = None
# Calculate expiry
duration = mrole.config.duration.value
if duration is not None:
expiry = now + dt.timedelta(seconds=duration)
else:
expiry = None
# Add to equip history
equip = await self.cog.data.RoleMenuHistory.create(
menuid=self.data.menuid, roleid=role.id,
userid=member.id,
obtained_at=now,
transactionid=tid,
expires_at=expiry
) )
await self.cog.schedule_expiring(equip) except discord.HTTPException:
raise UserInputError(
t(_p(
'rolemenu|select|error:discord',
"An unknown error occurred while assigning your role! "
"Please try again later."
))
)
now = utc_now()
# Create transaction if applicable
if price:
economy: Economy = self.bot.get_cog('Economy')
tx = await economy.data.Transaction.execute_transaction(
transaction_type=TransactionType.OTHER,
guildid=guild.id, actorid=member.id,
from_account=member.id, to_account=None,
amount=price
)
tid = tx.transactionid
else:
tid = None
# Calculate expiry
duration = mrole.config.duration.value
if duration is not None:
expiry = now + dt.timedelta(seconds=duration)
else:
expiry = None
# Add to equip history
equip = await self.cog.data.RoleMenuHistory.create(
menuid=self.data.menuid, roleid=role.id,
userid=member.id,
obtained_at=now,
transactionid=tid,
expires_at=expiry
)
await self.cog.schedule_expiring(equip)
# Ack the selection # Ack the selection
embed = discord.Embed( embed = discord.Embed(

View File

@@ -259,17 +259,6 @@ class RoomCog(LionCog):
lguild, lguild,
[member.id for member in members] [member.id for member in members]
) )
self._start(room)
# Send tips message
# TODO: Actual tips.
await channel.send(
"{mention} welcome to your private room! You may use the menu below to configure it.".format(mention=owner.mention)
)
# Send config UI
ui = RoomUI(self.bot, room, callerid=owner.id, timeout=None)
await ui.send(channel)
except Exception: except Exception:
try: try:
await channel.delete(reason="Failed to created private room") await channel.delete(reason="Failed to created private room")
@@ -454,71 +443,9 @@ class RoomCog(LionCog):
return return
# Positive response. Start a transaction. # Positive response. Start a transaction.
conn = await self.bot.db.get_connection() room = await self._do_create_room(ctx, required, days, rent, name, provided)
async with conn.transaction():
# Check member balance is sufficient
await ctx.alion.data.refresh()
member_balance = ctx.alion.data.coins
if member_balance < required:
await ctx.reply(
embed=error_embed(
t(_np(
'cmd:room_rent|error:insufficient_funds',
"Renting a private room for `one` day costs {coin}**{required}**, "
"but you only have {coin}**{balance}**!",
"Renting a private room for `{days}` days costs {coin}**{required}**, "
"but you only have {coin}**{balance}**!",
days
)).format(
coin=self.bot.config.emojis.coin,
balance=member_balance,
required=required,
days=days
),
ephemeral=True
)
)
return
# Deduct balance
# TODO: Economy transaction instead of manual deduction
await ctx.alion.data.update(coins=CoreData.Member.coins - required)
# Create room with given starting balance and other parameters
try:
room = await self.create_private_room(
ctx.guild,
ctx.author,
required - rent,
name or ctx.author.display_name,
members=provided
)
except discord.Forbidden:
await ctx.reply(
embed=error_embed(
t(_p(
'cmd:room_rent|error:my_permissions',
"Could not create your private room! You were not charged.\n"
"I have insufficient permissions to create a private room channel."
)),
)
)
await ctx.alion.data.update(coins=CoreData.Member.coins + required)
return
except discord.HTTPException as e:
await ctx.reply(
embed=error_embed(
t(_p(
'cmd:room_rent|error:unknown',
"Could not create your private room! You were not charged.\n"
"An unknown error occurred while creating your private room.\n"
"`{error}`"
)).format(error=e.text),
)
)
await ctx.alion.data.update(coins=CoreData.Member.coins + required)
return
if room:
# Ack with confirmation message pointing to the room # Ack with confirmation message pointing to the room
msg = t(_p( msg = t(_p(
'cmd:room_rent|success', 'cmd:room_rent|success',
@@ -531,6 +458,90 @@ class RoomCog(LionCog):
description=msg description=msg
) )
) )
self._start(room)
# Send tips message
# TODO: Actual tips.
await room.channel.send(
"{mention} welcome to your private room! You may use the menu below to configure it.".format(
mention=ctx.author.mention
)
)
# Send config UI
ui = RoomUI(self.bot, room, callerid=ctx.author.id, timeout=None)
await ui.send(room.channel)
@log_wrap(action='create_room')
async def _do_create_room(self, ctx, required, days, rent, name, provided) -> Room:
t = self.bot.translator.t
# TODO: Rollback the channel create if this fails
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
# Note that the room creation will go into the UI as well.
async with conn.transaction():
# Check member balance is sufficient
await ctx.alion.data.refresh()
member_balance = ctx.alion.data.coins
if member_balance < required:
await ctx.reply(
embed=error_embed(
t(_np(
'cmd:room_rent|error:insufficient_funds',
"Renting a private room for `one` day costs {coin}**{required}**, "
"but you only have {coin}**{balance}**!",
"Renting a private room for `{days}` days costs {coin}**{required}**, "
"but you only have {coin}**{balance}**!",
days
)).format(
coin=self.bot.config.emojis.coin,
balance=member_balance,
required=required,
days=days
),
ephemeral=True
)
)
return
# Deduct balance
# TODO: Economy transaction instead of manual deduction
await ctx.alion.data.update(coins=CoreData.Member.coins - required)
# Create room with given starting balance and other parameters
try:
return await self.create_private_room(
ctx.guild,
ctx.author,
required - rent,
name or ctx.author.display_name,
members=provided
)
except discord.Forbidden:
await ctx.reply(
embed=error_embed(
t(_p(
'cmd:room_rent|error:my_permissions',
"Could not create your private room! You were not charged.\n"
"I have insufficient permissions to create a private room channel."
)),
)
)
await ctx.alion.data.update(coins=CoreData.Member.coins + required)
return
except discord.HTTPException as e:
await ctx.reply(
embed=error_embed(
t(_p(
'cmd:room_rent|error:unknown',
"Could not create your private room! You were not charged.\n"
"An unknown error occurred while creating your private room.\n"
"`{error}`"
)).format(error=e.text),
)
)
await ctx.alion.data.update(coins=CoreData.Member.coins + required)
return
@room_group.command( @room_group.command(
name=_p('cmd:room_status', "status"), name=_p('cmd:room_status', "status"),
@@ -864,43 +875,41 @@ class RoomCog(LionCog):
return return
# Start Transaction # Start Transaction
conn = await self.bot.db.get_connection() # TODO: Economy transaction
async with conn.transaction(): await ctx.alion.data.refresh()
await ctx.alion.data.refresh() member_balance = ctx.alion.data.coins
member_balance = ctx.alion.data.coins if member_balance < coins:
if member_balance < coins: await ctx.reply(
await ctx.reply( embed=error_embed(t(_p(
embed=error_embed(t(_p( 'cmd:room_deposit|error:insufficient_funds',
'cmd:room_deposit|error:insufficient_funds', "You cannot deposit {coin}**{amount}**! You only have {coin}**{balance}**."
"You cannot deposit {coin}**{amount}**! You only have {coin}**{balance}**." )).format(
)).format( coin=self.bot.config.emojis.coin,
coin=self.bot.config.emojis.coin, amount=coins,
amount=coins, balance=member_balance
balance=member_balance )),
)), ephemeral=True
ephemeral=True )
) return
return
# Deduct balance # Deduct balance
# TODO: Economy transaction await ctx.alion.data.update(coins=CoreData.Member.coins - coins)
await ctx.alion.data.update(coins=CoreData.Member.coins - coins) await room.data.update(coin_balance=RoomData.Room.coin_balance + coins)
await room.data.update(coin_balance=RoomData.Room.coin_balance + coins)
# Post deposit message # Post deposit message
await room.notify_deposit(ctx.author, coins) await room.notify_deposit(ctx.author, coins)
# Ack the deposit # Ack the deposit
if ctx.channel.id != room.data.channelid: if ctx.channel.id != room.data.channelid:
ack_msg = t(_p( ack_msg = t(_p(
'cmd:room_depost|success', 'cmd:room_depost|success',
"Success! You have contributed {coin}**{amount}** to the private room bank." "Success! You have contributed {coin}**{amount}** to the private room bank."
)).format(coin=self.bot.config.emojis.coin, amount=coins) )).format(coin=self.bot.config.emojis.coin, amount=coins)
await ctx.reply( await ctx.reply(
embed=discord.Embed(colour=discord.Colour.brand_green(), description=ack_msg) embed=discord.Embed(colour=discord.Colour.brand_green(), description=ack_msg)
) )
else: else:
await ctx.interaction.delete_original_response() await ctx.interaction.delete_original_response()
# ----- Guild Configuration ----- # ----- Guild Configuration -----
@LionCog.placeholder_group @LionCog.placeholder_group

View File

@@ -7,6 +7,7 @@ from discord.ui.select import select, UserSelect
from meta import LionBot, conf from meta import LionBot, conf
from meta.errors import UserInputError from meta.errors import UserInputError
from meta.logger import log_wrap
from babel.translator import ctx_locale from babel.translator import ctx_locale
from utils.lib import utc_now, MessageArgs, error_embed from utils.lib import utc_now, MessageArgs, error_embed
from utils.ui import MessageUI, input from utils.ui import MessageUI, input
@@ -115,38 +116,43 @@ class RoomUI(MessageUI):
return return
await submit.response.defer(thinking=True, ephemeral=True) await submit.response.defer(thinking=True, ephemeral=True)
await self._do_deposit(t, press, amount, submit)
# Post deposit message
await self.room.notify_deposit(press.user, amount)
await self.refresh(thinking=submit)
@log_wrap(isolate=True)
async def _do_deposit(self, t, press, amount, submit):
# Start transaction for deposit # Start transaction for deposit
conn = await self.bot.db.get_connection() async with self.bot.db.connection() as conn:
async with conn.transaction(): self.bot.db.conn = conn
# Get the lion balance directly async with conn.transaction():
lion = await self.bot.core.data.Member.fetch( # Get the lion balance directly
self.room.data.guildid, lion = await self.bot.core.data.Member.fetch(
press.user.id, self.room.data.guildid,
cached=False press.user.id,
) cached=False
balance = lion.coins )
if balance < amount: balance = lion.coins
await submit.edit_original_response( if balance < amount:
embed=error_embed( await submit.edit_original_response(
t(_p( embed=error_embed(
'ui:room_status|button:deposit|error:insufficient_funds', t(_p(
"You cannot deposit {coin}**{amount}**! You only have {coin}**{balance}**." 'ui:room_status|button:deposit|error:insufficient_funds',
)).format( "You cannot deposit {coin}**{amount}**! You only have {coin}**{balance}**."
coin=self.bot.config.emojis.coin, )).format(
amount=amount, coin=self.bot.config.emojis.coin,
balance=balance amount=amount,
balance=balance
)
) )
) )
) return
return # TODO: Economy Transaction
# TODO: Economy Transaction await lion.update(coins=CoreData.Member.coins - amount)
await lion.update(coins=CoreData.Member.coins - amount) await self.room.data.update(coin_balance=RoomData.Room.coin_balance + amount)
await self.room.data.update(coin_balance=RoomData.Room.coin_balance + amount)
# Post deposit message
await self.room.notify_deposit(press.user, amount)
await self.refresh(thinking=submit)
async def desposit_button_refresh(self): async def desposit_button_refresh(self):
self.desposit_button.label = self.bot.translator.t(_p( self.desposit_button.label = self.bot.translator.t(_p(

View File

@@ -217,23 +217,24 @@ class ScheduleCog(LionCog):
for bookingid in bookingids: for bookingid in bookingids:
await self._cancel_booking_active(*bookingid) await self._cancel_booking_active(*bookingid)
conn = await self.bot.db.get_connection() async with self.bot.db.connection() as conn:
async with conn.transaction(): self.bot.db.conn = conn
# Now delete from data async with conn.transaction():
records = await self.data.ScheduleSessionMember.table.delete_where( # Now delete from data
MULTIVALUE_IN( records = await self.data.ScheduleSessionMember.table.delete_where(
('slotid', 'guildid', 'userid'), MULTIVALUE_IN(
*bookingids ('slotid', 'guildid', 'userid'),
*bookingids
)
) )
)
# Refund cancelled bookings # Refund cancelled bookings
if refund: if refund:
maybe_tids = (record['book_transactionid'] for record in records) maybe_tids = (record['book_transactionid'] for record in records)
tids = [tid for tid in maybe_tids if tid is not None] tids = [tid for tid in maybe_tids if tid is not None]
if tids: if tids:
economy = self.bot.get_cog('Economy') economy = self.bot.get_cog('Economy')
await economy.data.Transaction.refund_transactions(*tids) await economy.data.Transaction.refund_transactions(*tids)
finally: finally:
for lock in locks: for lock in locks:
lock.release() lock.release()
@@ -473,7 +474,6 @@ class ScheduleCog(LionCog):
"One or more requested timeslots are already booked!" "One or more requested timeslots are already booked!"
)) ))
raise UserInputError(error) raise UserInputError(error)
conn = await self.bot.db.get_connection()
# Booking request is now validated. Perform bookings. # Booking request is now validated. Perform bookings.
# Fetch or create session data # Fetch or create session data
@@ -482,27 +482,27 @@ class ScheduleCog(LionCog):
*((guildid, slotid) for slotid in slotids) *((guildid, slotid) for slotid in slotids)
) )
async with conn.transaction(): # Create transactions
# Create transactions # TODO: wrap in a transaction so the economy transaction gets unwound if it fails
economy = self.bot.get_cog('Economy') economy = self.bot.get_cog('Economy')
trans_data = ( trans_data = (
TransactionType.SCHEDULE_BOOK, TransactionType.SCHEDULE_BOOK,
guildid, userid, userid, 0, guildid, userid, userid, 0,
config.get(ScheduleSettings.ScheduleCost.setting_id).value, config.get(ScheduleSettings.ScheduleCost.setting_id).value,
0, None 0, None
) )
transactions = await economy.data.Transaction.execute_transactions(*(trans_data for _ in slotids)) transactions = await economy.data.Transaction.execute_transactions(*(trans_data for _ in slotids))
transactionids = [row.transactionid for row in transactions] transactionids = [row.transactionid for row in transactions]
# Create bookings # Create bookings
now = utc_now() now = utc_now()
booking_data = await self.data.ScheduleSessionMember.table.insert_many( booking_data = await self.data.ScheduleSessionMember.table.insert_many(
('guildid', 'userid', 'slotid', 'booked_at', 'book_transactionid'), ('guildid', 'userid', 'slotid', 'booked_at', 'book_transactionid'),
*( *(
(guildid, userid, slotid, now, tid) (guildid, userid, slotid, now, tid)
for slotid, tid in zip(slotids, transactionids) for slotid, tid in zip(slotids, transactionids)
)
) )
)
# Now pass to activated slots # Now pass to activated slots
for record in booking_data: for record in booking_data:

View File

@@ -356,77 +356,76 @@ class TimeSlot:
Does not modify session room channels (responsibility of the next open). Does not modify session room channels (responsibility of the next open).
""" """
try: try:
conn = await self.bot.db.get_connection() # TODO: Transaction?
async with conn.transaction(): # Calculate rewards
# Calculate rewards rewards = []
rewards = [] attendance = []
attendance = [] did_not_show = []
did_not_show = [] for session in sessions:
for session in sessions: bonus = session.bonus_reward * session.all_attended
bonus = session.bonus_reward * session.all_attended reward = session.attended_reward + bonus
reward = session.attended_reward + bonus required = session.min_attendence
required = session.min_attendence for member in session.members.values():
for member in session.members.values(): guildid = member.guildid
guildid = member.guildid userid = member.userid
userid = member.userid attended = (member.total_clock >= required)
attended = (member.total_clock >= required) if attended:
if attended: rewards.append(
rewards.append( (TransactionType.SCHEDULE_REWARD,
(TransactionType.SCHEDULE_REWARD, guildid, self.bot.user.id,
guildid, self.bot.user.id, 0, userid,
0, userid, reward, 0,
reward, 0, None)
None)
)
else:
did_not_show.append((guildid, userid))
attendance.append(
(self.slotid, guildid, userid, attended, member.total_clock)
) )
else:
did_not_show.append((guildid, userid))
# Perform economy transactions attendance.append(
economy: Economy = self.bot.get_cog('Economy') (self.slotid, guildid, userid, attended, member.total_clock)
transactions = await economy.data.Transaction.execute_transactions(*rewards)
reward_ids = {
(t.guildid, t.to_account): t.transactionid
for t in transactions
}
# Update lobby messages
message_tasks = [
asyncio.create_task(session.update_status(save=False))
for session in sessions
if session.lobby_channel is not None
]
await asyncio.gather(*message_tasks)
# Save attendance
if attendance:
att_table = TemporaryTable(
'_sid', '_gid', '_uid', '_att', '_clock', '_reward',
types=('INTEGER', 'BIGINT', 'BIGINT', 'BOOLEAN', 'INTEGER', 'INTEGER')
) )
att_table.values = [
(sid, gid, uid, att, clock, reward_ids.get((gid, uid), None))
for sid, gid, uid, att, clock in attendance
]
await self.data.ScheduleSessionMember.table.update_where(
slotid=att_table['_sid'],
guildid=att_table['_gid'],
userid=att_table['_uid'],
).set(
attended=att_table['_att'],
clock=att_table['_clock'],
reward_transactionid=att_table['_reward']
).from_expr(att_table)
# Mark guild sessions as closed # Perform economy transactions
if sessions: economy: Economy = self.bot.get_cog('Economy')
await self.data.ScheduleSession.table.update_where( transactions = await economy.data.Transaction.execute_transactions(*rewards)
slotid=self.slotid, reward_ids = {
guildid=list(session.guildid for session in sessions) (t.guildid, t.to_account): t.transactionid
).set(closed_at=utc_now()) for t in transactions
}
# Update lobby messages
message_tasks = [
asyncio.create_task(session.update_status(save=False))
for session in sessions
if session.lobby_channel is not None
]
await asyncio.gather(*message_tasks)
# Save attendance
if attendance:
att_table = TemporaryTable(
'_sid', '_gid', '_uid', '_att', '_clock', '_reward',
types=('INTEGER', 'BIGINT', 'BIGINT', 'BOOLEAN', 'INTEGER', 'INTEGER')
)
att_table.values = [
(sid, gid, uid, att, clock, reward_ids.get((gid, uid), None))
for sid, gid, uid, att, clock in attendance
]
await self.data.ScheduleSessionMember.table.update_where(
slotid=att_table['_sid'],
guildid=att_table['_gid'],
userid=att_table['_uid'],
).set(
attended=att_table['_att'],
clock=att_table['_clock'],
reward_transactionid=att_table['_reward']
).from_expr(att_table)
# Mark guild sessions as closed
if sessions:
await self.data.ScheduleSession.table.update_where(
slotid=self.slotid,
guildid=list(session.guildid for session in sessions)
).set(closed_at=utc_now())
if consequences and did_not_show: if consequences and did_not_show:
# Trigger blacklist and cancel member bookings as needed # Trigger blacklist and cancel member bookings as needed
@@ -532,36 +531,35 @@ class TimeSlot:
This involves refunding the booking transactions, deleting the booking rows, This involves refunding the booking transactions, deleting the booking rows,
and updating any messages that may have been posted. and updating any messages that may have been posted.
""" """
conn = await self.bot.db.get_connection() # TODO: Transaction
async with conn.transaction(): # Collect booking rows
# Collect booking rows bookings = [member.data for session in sessions for member in session.members.values()]
bookings = [member.data for session in sessions for member in session.members.values()]
if bookings: if bookings:
# Refund booking transactions # Refund booking transactions
economy: Economy = self.bot.get_cog('Economy') economy: Economy = self.bot.get_cog('Economy')
maybe_tids = (r.book_transactionid for r in bookings) maybe_tids = (r.book_transactionid for r in bookings)
tids = [tid for tid in maybe_tids if tid is not None] tids = [tid for tid in maybe_tids if tid is not None]
await economy.data.Transaction.refund_transactions(*tids) await economy.data.Transaction.refund_transactions(*tids)
# Delete booking rows # Delete booking rows
await self.data.ScheduleSessionMember.table.delete_where( await self.data.ScheduleSessionMember.table.delete_where(
MEMBERS(*((r.guildid, r.userid) for r in bookings)), MEMBERS(*((r.guildid, r.userid) for r in bookings)),
slotid=self.slotid,
)
# Trigger message update for existent messages
lobby_tasks = [
asyncio.create_task(session.update_status(save=False, resend=False))
for session in sessions
]
await asyncio.gather(*lobby_tasks)
# Mark sessions as closed
await self.data.ScheduleSession.table.update_where(
slotid=self.slotid, slotid=self.slotid,
guildid=[session.guildid for session in sessions]
).set(
closed_at=utc_now()
) )
# TODO: Logging
# Trigger message update for existent messages
lobby_tasks = [
asyncio.create_task(session.update_status(save=False, resend=False))
for session in sessions
]
await asyncio.gather(*lobby_tasks)
# Mark sessions as closed
await self.data.ScheduleSession.table.update_where(
slotid=self.slotid,
guildid=[session.guildid for session in sessions]
).set(
closed_at=utc_now()
)
# TODO: Logging

View File

@@ -10,6 +10,7 @@ from discord.ui.button import button, Button
from meta import LionCog, LionContext, LionBot from meta import LionCog, LionContext, LionBot
from meta.errors import SafeCancellation from meta.errors import SafeCancellation
from meta.logger import log_wrap
from utils import ui from utils import ui
from utils.lib import error_embed from utils.lib import error_embed
from constants import MAX_COINS from constants import MAX_COINS
@@ -145,6 +146,7 @@ class ColourShop(Shop):
if (owned is None or item.itemid != owned.itemid) and (item.price <= balance) if (owned is None or item.itemid != owned.itemid) and (item.price <= balance)
] ]
@log_wrap(action='purchase')
async def purchase(self, itemid) -> ColourRoleItem: async def purchase(self, itemid) -> ColourRoleItem:
""" """
Atomically handle a purchase of a ColourRoleItem. Atomically handle a purchase of a ColourRoleItem.
@@ -157,144 +159,145 @@ class ColourShop(Shop):
If the purchase fails for a known reason, raises SafeCancellation, with the error information. If the purchase fails for a known reason, raises SafeCancellation, with the error information.
""" """
t = self.bot.translator.t t = self.bot.translator.t
conn = await self.bot.db.get_connection() async with self.bot.db.connection() as conn:
async with conn.transaction(): self.bot.db.conn = conn
# Retrieve the item to purchase from data async with conn.transaction():
item = await self.data.ShopItemInfo.table.select_one_where(itemid=itemid) # Retrieve the item to purchase from data
# Ensure the item is purchasable and not deleted item = await self.data.ShopItemInfo.table.select_one_where(itemid=itemid)
if not item['purchasable'] or item['deleted']: # Ensure the item is purchasable and not deleted
raise SafeCancellation( if not item['purchasable'] or item['deleted']:
t(_p(
'shop:colour|purchase|error:not_purchasable',
"This item may not be purchased!"
))
)
# Refresh the customer
await self.customer.refresh()
# Ensure the guild exists in cache
guild = self.bot.get_guild(self.customer.guildid)
if guild is None:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:no_guild',
"Could not retrieve the server from Discord!"
))
)
# Ensure the customer member actually exists
member = await self.customer.lion.fetch_member()
if member is None:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:no_member',
"Could not retrieve the member from Discord."
))
)
# Ensure the purchased role actually exists
role = guild.get_role(item['roleid'])
if role is None:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:no_role',
"This colour role could not be found in the server."
))
)
# Ensure the customer has enough coins for the item
if self.customer.balance < item['price']:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:low_balance',
"This item costs {coin}{amount}!\nYour balance is {coin}{balance}"
)).format(
coin=self.bot.config.emojis.getemoji('coin'),
amount=item['price'],
balance=self.customer.balance
)
)
owned = self.owned()
if owned is not None:
# Ensure the customer does not already own the item
if owned.itemid == item['itemid']:
raise SafeCancellation( raise SafeCancellation(
t(_p( t(_p(
'shop:colour|purchase|error:owned', 'shop:colour|purchase|error:not_purchasable',
"You already own this item!" "This item may not be purchased!"
)) ))
) )
# Charge the customer for the item # Refresh the customer
economy_cog: Economy = self.bot.get_cog('Economy') await self.customer.refresh()
economy_data = economy_cog.data
transaction = await economy_data.ShopTransaction.purchase_transaction(
guild.id,
member.id,
member.id,
itemid,
item['price']
)
# Add the item to the customer's inventory # Ensure the guild exists in cache
await self.data.MemberInventory.create( guild = self.bot.get_guild(self.customer.guildid)
guildid=guild.id, if guild is None:
userid=member.id, raise SafeCancellation(
transactionid=transaction.transactionid, t(_p(
itemid=itemid 'shop:colour|purchase|error:no_guild',
) "Could not retrieve the server from Discord!"
))
)
# Give the customer the role (do rollback if this fails) # Ensure the customer member actually exists
try: member = await self.customer.lion.fetch_member()
await member.add_roles( if member is None:
role, raise SafeCancellation(
atomic=True, t(_p(
reason="Purchased colour role" 'shop:colour|purchase|error:no_member',
) "Could not retrieve the member from Discord."
except discord.NotFound: ))
raise SafeCancellation( )
t(_p(
'shop:colour|purchase|error:failed_no_role',
"This colour role no longer exists!"
))
)
except discord.Forbidden:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:failed_permissions',
"I do not have enough permissions to give you this colour role!"
))
)
except discord.HTTPException:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:failed_unknown',
"An unknown error occurred while giving you this colour role!"
))
)
# At this point, the purchase has succeeded and the user has obtained the colour role # Ensure the purchased role actually exists
# Now, remove their previous colour role (if applicable) role = guild.get_role(item['roleid'])
# TODO: We should probably add an on_role_delete event to clear defunct colour roles if role is None:
if owned is not None: raise SafeCancellation(
owned_role = owned.role t(_p(
if owned_role is not None: 'shop:colour|purchase|error:no_role',
try: "This colour role could not be found in the server."
await member.remove_roles( ))
owned_role, )
reason="Removing old colour role.",
atomic=True # Ensure the customer has enough coins for the item
if self.customer.balance < item['price']:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:low_balance',
"This item costs {coin}{amount}!\nYour balance is {coin}{balance}"
)).format(
coin=self.bot.config.emojis.getemoji('coin'),
amount=item['price'],
balance=self.customer.balance
) )
except discord.HTTPException: )
# Possibly Forbidden, or the role doesn't actually exist anymore (cache failure)
pass
await self.data.MemberInventory.table.delete_where(inventoryid=owned.data.inventoryid)
# Purchase complete, update the shop and customer owned = self.owned()
await self.refresh() if owned is not None:
return self.owned() # Ensure the customer does not already own the item
if owned.itemid == item['itemid']:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:owned',
"You already own this item!"
))
)
# Charge the customer for the item
economy_cog: Economy = self.bot.get_cog('Economy')
economy_data = economy_cog.data
transaction = await economy_data.ShopTransaction.purchase_transaction(
guild.id,
member.id,
member.id,
itemid,
item['price']
)
# Add the item to the customer's inventory
await self.data.MemberInventory.create(
guildid=guild.id,
userid=member.id,
transactionid=transaction.transactionid,
itemid=itemid
)
# Give the customer the role (do rollback if this fails)
try:
await member.add_roles(
role,
atomic=True,
reason="Purchased colour role"
)
except discord.NotFound:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:failed_no_role',
"This colour role no longer exists!"
))
)
except discord.Forbidden:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:failed_permissions',
"I do not have enough permissions to give you this colour role!"
))
)
except discord.HTTPException:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:failed_unknown',
"An unknown error occurred while giving you this colour role!"
))
)
# At this point, the purchase has succeeded and the user has obtained the colour role
# Now, remove their previous colour role (if applicable)
# TODO: We should probably add an on_role_delete event to clear defunct colour roles
if owned is not None:
owned_role = owned.role
if owned_role is not None:
try:
await member.remove_roles(
owned_role,
reason="Removing old colour role.",
atomic=True
)
except discord.HTTPException:
# Possibly Forbidden, or the role doesn't actually exist anymore (cache failure)
pass
await self.data.MemberInventory.table.delete_where(inventoryid=owned.data.inventoryid)
# Purchase complete, update the shop and customer
await self.refresh()
return self.owned()
async def refresh(self): async def refresh(self):
""" """

View File

@@ -4,6 +4,7 @@ from enum import Enum
from itertools import chain from itertools import chain
from psycopg import sql from psycopg import sql
from meta.logger import log_wrap
from data import RowModel, Registry, Table, RegisterEnum from data import RowModel, Registry, Table, RegisterEnum
from data.columns import Integer, String, Timestamp, Bool, Column from data.columns import Integer, String, Timestamp, Bool, Column
@@ -80,6 +81,7 @@ class StatsData(Registry):
end_time = Timestamp() end_time = Timestamp()
@classmethod @classmethod
@log_wrap(action='tracked_time_between')
async def tracked_time_between(cls, *points: tuple[int, int, dt.datetime, dt.datetime]): async def tracked_time_between(cls, *points: tuple[int, int, dt.datetime, dt.datetime]):
query = sql.SQL( query = sql.SQL(
""" """
@@ -103,25 +105,27 @@ class StatsData(Registry):
for _ in points for _ in points
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
chain(*points) chain(*points)
) )
return cursor.fetchall() return cursor.fetchall()
@classmethod @classmethod
@log_wrap(action='study_time_between')
async def study_time_between(cls, guildid: int, userid: int, _start, _end) -> int: async def study_time_between(cls, guildid: int, userid: int, _start, _end) -> int:
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
"SELECT study_time_between(%s, %s, %s, %s)", "SELECT study_time_between(%s, %s, %s, %s)",
(guildid, userid, _start, _end) (guildid, userid, _start, _end)
) )
return (await cursor.fetchone()[0]) or 0 return (await cursor.fetchone()[0]) or 0
@classmethod @classmethod
@log_wrap(action='study_times_between')
async def study_times_between(cls, guildid: int, userid: int, *points) -> list[int]: async def study_times_between(cls, guildid: int, userid: int, *points) -> list[int]:
if len(points) < 2: if len(points) < 2:
raise ValueError('Not enough block points given!') raise ValueError('Not enough block points given!')
@@ -141,25 +145,27 @@ class StatsData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:] sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
tuple(chain((guildid, userid), *blocks)) tuple(chain((guildid, userid), *blocks))
) )
return [r['stime'] or 0 for r in await cursor.fetchall()] return [r['stime'] or 0 for r in await cursor.fetchall()]
@classmethod @classmethod
@log_wrap(action='study_time_since')
async def study_time_since(cls, guildid: int, userid: int, _start) -> int: async def study_time_since(cls, guildid: int, userid: int, _start) -> int:
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
"SELECT study_time_since(%s, %s, %s)", "SELECT study_time_since(%s, %s, %s)",
(guildid, userid, _start) (guildid, userid, _start)
) )
return (await cursor.fetchone()[0]) or 0 return (await cursor.fetchone()[0]) or 0
@classmethod @classmethod
@log_wrap(action='study_times_between')
async def study_times_since(cls, guildid: int, userid: int, *starts) -> int: async def study_times_since(cls, guildid: int, userid: int, *starts) -> int:
if len(starts) < 1: if len(starts) < 1:
raise ValueError('No starting points given!') raise ValueError('No starting points given!')
@@ -178,15 +184,16 @@ class StatsData(Registry):
sql.SQL("({})").format(sql.Placeholder()) for _ in starts sql.SQL("({})").format(sql.Placeholder()) for _ in starts
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
tuple(chain((guildid, userid), starts)) tuple(chain((guildid, userid), starts))
) )
return [r['stime'] or 0 for r in await cursor.fetchall()] return [r['stime'] or 0 for r in await cursor.fetchall()]
@classmethod @classmethod
@log_wrap(action='leaderboard_since')
async def leaderboard_since(cls, guildid: int, since): async def leaderboard_since(cls, guildid: int, since):
""" """
Return the voice totals since the given time for each member in the guild. Return the voice totals since the given time for each member in the guild.
@@ -226,23 +233,25 @@ class StatsData(Registry):
) )
second_query_args = (since, guildid, since, since) second_query_args = (since, guildid, since, since)
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.transaction(): cls._connector.conn = conn
async with conn.cursor() as cursor: async with conn.transaction():
await cursor.execute(second_query, second_query_args) async with conn.cursor() as cursor:
overshoot_rows = await cursor.fetchall() await cursor.execute(second_query, second_query_args)
overshoot = {row['userid']: int(row['diff']) for row in overshoot_rows} overshoot_rows = await cursor.fetchall()
overshoot = {row['userid']: int(row['diff']) for row in overshoot_rows}
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute(first_query, first_query_args) await cursor.execute(first_query, first_query_args)
leaderboard = [ leaderboard = [
(row['userid'], int(row['total_duration'] - overshoot.get(row['userid'], 0))) (row['userid'], int(row['total_duration'] - overshoot.get(row['userid'], 0)))
for row in await cursor.fetchall() for row in await cursor.fetchall()
] ]
leaderboard.sort(key=lambda t: t[1], reverse=True) leaderboard.sort(key=lambda t: t[1], reverse=True)
return leaderboard return leaderboard
@classmethod @classmethod
@log_wrap('leaderboard_all')
async def leaderboard_all(cls, guildid: int): async def leaderboard_all(cls, guildid: int):
""" """
Return the all-time voice totals for the given guild. Return the all-time voice totals for the given guild.
@@ -257,13 +266,13 @@ class StatsData(Registry):
""" """
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute(query, (guildid, )) await cursor.execute(query, (guildid, ))
leaderboard = [ leaderboard = [
(row['userid'], int(row['total_duration'])) (row['userid'], int(row['total_duration']))
for row in await cursor.fetchall() for row in await cursor.fetchall()
] ]
return leaderboard return leaderboard
class MemberExp(RowModel): class MemberExp(RowModel):
@@ -296,6 +305,7 @@ class StatsData(Registry):
transactionid = Integer() transactionid = Integer()
@classmethod @classmethod
@log_wrap(action='xp_since')
async def xp_since(cls, guildid: int, userid: int, *starts): async def xp_since(cls, guildid: int, userid: int, *starts):
query = sql.SQL( query = sql.SQL(
""" """
@@ -320,15 +330,16 @@ class StatsData(Registry):
sql.Placeholder() for _ in starts sql.Placeholder() for _ in starts
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
tuple(chain((guildid, userid), starts)) tuple(chain((guildid, userid), starts))
) )
return [r['exp'] or 0 for r in await cursor.fetchall()] return [r['exp'] or 0 for r in await cursor.fetchall()]
@classmethod @classmethod
@log_wrap(action='xp_between')
async def xp_between(cls, guildid: int, userid: int, *points): async def xp_between(cls, guildid: int, userid: int, *points):
blocks = zip(points, points[1:]) blocks = zip(points, points[1:])
query = sql.SQL( query = sql.SQL(
@@ -355,15 +366,16 @@ class StatsData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:] sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
tuple(chain((guildid, userid), *blocks)) tuple(chain((guildid, userid), *blocks))
) )
return [r['period_xp'] or 0 for r in await cursor.fetchall()] return [r['period_xp'] or 0 for r in await cursor.fetchall()]
@classmethod @classmethod
@log_wrap(action='leaderboard_since')
async def leaderboard_since(cls, guildid: int, since): async def leaderboard_since(cls, guildid: int, since):
""" """
Return the XP totals for the given guild since the given time. Return the XP totals for the given guild since the given time.
@@ -378,16 +390,17 @@ class StatsData(Registry):
""" """
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute(query, (guildid, since)) await cursor.execute(query, (guildid, since))
leaderboard = [ leaderboard = [
(row['userid'], int(row['total_xp'])) (row['userid'], int(row['total_xp']))
for row in await cursor.fetchall() for row in await cursor.fetchall()
] ]
return leaderboard return leaderboard
@classmethod @classmethod
@log_wrap(action='leaderboard_all')
async def leaderboard_all(cls, guildid: int): async def leaderboard_all(cls, guildid: int):
""" """
Return the all-time XP totals for the given guild. Return the all-time XP totals for the given guild.
@@ -402,13 +415,13 @@ class StatsData(Registry):
""" """
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute(query, (guildid, )) await cursor.execute(query, (guildid, ))
leaderboard = [ leaderboard = [
(row['userid'], int(row['total_xp'])) (row['userid'], int(row['total_xp']))
for row in await cursor.fetchall() for row in await cursor.fetchall()
] ]
return leaderboard return leaderboard
class UserExp(RowModel): class UserExp(RowModel):
@@ -436,6 +449,7 @@ class StatsData(Registry):
exp_type: Column[ExpType] = Column() exp_type: Column[ExpType] = Column()
@classmethod @classmethod
@log_wrap(action='user_xp_since')
async def xp_since(cls, userid: int, *starts): async def xp_since(cls, userid: int, *starts):
query = sql.SQL( query = sql.SQL(
""" """
@@ -459,15 +473,16 @@ class StatsData(Registry):
sql.Placeholder() for _ in starts sql.Placeholder() for _ in starts
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
tuple(chain((userid,), starts)) tuple(chain((userid,), starts))
) )
return [r['exp'] or 0 for r in await cursor.fetchall()] return [r['exp'] or 0 for r in await cursor.fetchall()]
@classmethod @classmethod
@log_wrap(action='user_xp_since')
async def xp_between(cls, userid: int, *points): async def xp_between(cls, userid: int, *points):
blocks = zip(points, points[1:]) blocks = zip(points, points[1:])
query = sql.SQL( query = sql.SQL(
@@ -493,13 +508,13 @@ class StatsData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:] sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
tuple(chain((userid,), *blocks)) tuple(chain((userid,), *blocks))
) )
return [r['period_xp'] or 0 for r in await cursor.fetchall()] return [r['period_xp'] or 0 for r in await cursor.fetchall()]
class ProfileTag(RowModel): class ProfileTag(RowModel):
""" """
@@ -531,15 +546,17 @@ class StatsData(Registry):
return [tag.tag for tag in tags] return [tag.tag for tag in tags]
@classmethod @classmethod
@log_wrap(action='set_profile_tags')
async def set_tags(self, guildid: Optional[int], userid: int, tags: Iterable[str]): async def set_tags(self, guildid: Optional[int], userid: int, tags: Iterable[str]):
conn = await self._connector.get_connection() async with self._connector.connection() as conn:
async with conn.transaction(): self._connector.conn = conn
await self.table.delete_where(guildid=guildid, userid=userid) async with conn.transaction():
if tags: await self.table.delete_where(guildid=guildid, userid=userid)
await self.table.insert_many( if tags:
('guildid', 'userid', 'tag'), await self.table.insert_many(
*((guildid, userid, tag) for tag in tags) ('guildid', 'userid', 'tag'),
) *((guildid, userid, tag) for tag in tags)
)
class WeeklyGoals(RowModel): class WeeklyGoals(RowModel):
""" """

View File

@@ -473,14 +473,14 @@ class WeeklyMonthlyUI(StatsUI):
# Update the tasklist # Update the tasklist
if len(new_tasks) != len(tasks) or not all(t == new_t for (t, new_t) in zip(tasks, new_tasks)): if len(new_tasks) != len(tasks) or not all(t == new_t for (t, new_t) in zip(tasks, new_tasks)):
modified = True modified = True
conn = await self.bot.db.get_connection() async with self._connector.connection() as conn:
async with conn.transaction(): async with conn.transaction():
await tasks_model.table.delete_where(**key) await tasks_model.table.delete_where(**key).with_connection(conn)
if new_tasks: if new_tasks:
await tasks_model.table.insert_many( await tasks_model.table.insert_many(
(*key.keys(), 'completed', 'content'), (*key.keys(), 'completed', 'content'),
*((*key.values(), *new_task) for new_task in new_tasks) *((*key.values(), *new_task) for new_task in new_tasks)
) ).with_connection(conn)
if modified: if modified:
# If either goal type was modified, clear the rendered cache and refresh # If either goal type was modified, clear the rendered cache and refresh

View File

@@ -8,6 +8,7 @@ from discord import app_commands as appcmds
from discord.app_commands.transformers import AppCommandOptionType as cmdopt from discord.app_commands.transformers import AppCommandOptionType as cmdopt
from meta import LionBot, LionCog, LionContext from meta import LionBot, LionCog, LionContext
from meta.logger import log_wrap
from meta.errors import UserInputError from meta.errors import UserInputError
from utils.lib import utc_now, error_embed from utils.lib import utc_now, error_embed
from utils.ui import ChoicedEnum, Transformed, AButton from utils.ui import ChoicedEnum, Transformed, AButton
@@ -141,30 +142,32 @@ class TasklistCog(LionCog):
self.crossload_group(self.configure_group, configcog.configure_group) self.crossload_group(self.configure_group, configcog.configure_group)
@LionCog.listener('on_tasks_completed') @LionCog.listener('on_tasks_completed')
@log_wrap(action="reward tasks completed")
async def reward_tasks_completed(self, member: discord.Member, *taskids: int): async def reward_tasks_completed(self, member: discord.Member, *taskids: int):
conn = await self.bot.db.get_connection() async with self.bot.db.connection() as conn:
async with conn.transaction(): self.bot.db.conn = conn
tasklist = await Tasklist.fetch(self.bot, self.data, member.id) async with conn.transaction():
tasks = await tasklist.fetch_tasks(*taskids) tasklist = await Tasklist.fetch(self.bot, self.data, member.id)
unrewarded = [task for task in tasks if not task.rewarded] tasks = await tasklist.fetch_tasks(*taskids)
if unrewarded: unrewarded = [task for task in tasks if not task.rewarded]
reward = (await self.settings.task_reward.get(member.guild.id)).value if unrewarded:
limit = (await self.settings.task_reward_limit.get(member.guild.id)).value reward = (await self.settings.task_reward.get(member.guild.id)).value
limit = (await self.settings.task_reward_limit.get(member.guild.id)).value
ecog = self.bot.get_cog('Economy') ecog = self.bot.get_cog('Economy')
recent = await ecog.data.TaskTransaction.count_recent_for(member.id, member.guild.id) or 0 recent = await ecog.data.TaskTransaction.count_recent_for(member.id, member.guild.id) or 0
max_to_reward = limit - recent max_to_reward = limit - recent
if max_to_reward > 0: if max_to_reward > 0:
to_reward = unrewarded[:max_to_reward] to_reward = unrewarded[:max_to_reward]
count = len(to_reward) count = len(to_reward)
amount = count * reward amount = count * reward
await ecog.data.TaskTransaction.reward_completed(member.id, member.guild.id, count, amount) await ecog.data.TaskTransaction.reward_completed(member.id, member.guild.id, count, amount)
await tasklist.update_tasks(*(task.taskid for task in to_reward), rewarded=True) await tasklist.update_tasks(*(task.taskid for task in to_reward), rewarded=True)
logger.debug( logger.debug(
f"Rewarded <uid: {member.id}> in <gid: {member.guild.id}> " f"Rewarded <uid: {member.id}> in <gid: {member.guild.id}> "
f"'{amount}' coins for completing '{count}' tasks." f"'{amount}' coins for completing '{count}' tasks."
) )
async def is_tasklist_channel(self, channel) -> bool: async def is_tasklist_channel(self, channel) -> bool:
if not channel.guild: if not channel.guild:
@@ -477,43 +480,40 @@ class TasklistCog(LionCog):
# Contents successfully parsed, update the tasklist. # Contents successfully parsed, update the tasklist.
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id) tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
# Lazily using the editor because it has a good parser
taskinfo = tasklist.parse_tasklist(lines) taskinfo = tasklist.parse_tasklist(lines)
conn = await self.bot.db.get_connection() now = utc_now()
async with conn.transaction():
now = utc_now()
# Delete tasklist if required # Delete tasklist if required
if not append: if not append:
await tasklist.update_tasklist(deleted_at=now) await tasklist.update_tasklist(deleted_at=now)
# Create tasklist # Create tasklist
# TODO: Refactor into common method with parse tasklist # TODO: Refactor into common method with parse tasklist
created = {} created = {}
target_depth = 0 target_depth = 0
while True: while True:
to_insert = {} to_insert = {}
for i, (parent, truedepth, ticked, content) in enumerate(taskinfo): for i, (parent, truedepth, ticked, content) in enumerate(taskinfo):
if truedepth == target_depth: if truedepth == target_depth:
to_insert[i] = ( to_insert[i] = (
tasklist.userid, tasklist.userid,
content, content,
created[parent] if parent is not None else None, created[parent] if parent is not None else None,
now if ticked else None now if ticked else None
)
if to_insert:
# Batch insert
tasks = await tasklist.data.Task.table.insert_many(
('userid', 'content', 'parentid', 'completed_at'),
*to_insert.values()
) )
for i, task in zip(to_insert.keys(), tasks): if to_insert:
created[i] = task['taskid'] # Batch insert
target_depth += 1 tasks = await tasklist.data.Task.table.insert_many(
else: ('userid', 'content', 'parentid', 'completed_at'),
# Reached maximum depth *to_insert.values()
break )
for i, task in zip(to_insert.keys(), tasks):
created[i] = task['taskid']
target_depth += 1
else:
# Reached maximum depth
break
# Ack modifications # Ack modifications
embed = discord.Embed( embed = discord.Embed(

View File

@@ -11,6 +11,7 @@ from discord.ui.button import button, Button, ButtonStyle
from discord.ui.text_input import TextInput, TextStyle from discord.ui.text_input import TextInput, TextStyle
from meta import conf from meta import conf
from meta.logger import log_wrap
from meta.errors import UserInputError from meta.errors import UserInputError
from utils.lib import MessageArgs, utc_now from utils.lib import MessageArgs, utc_now
from utils.ui import LeoUI, LeoModal, FastModal, error_handler_for, ModalRetryUI from utils.ui import LeoUI, LeoModal, FastModal, error_handler_for, ModalRetryUI
@@ -143,6 +144,7 @@ class BulkEditor(LeoModal):
except UserInputError as error: except UserInputError as error:
await ModalRetryUI(self, error.msg).respond_to(interaction) await ModalRetryUI(self, error.msg).respond_to(interaction)
@log_wrap(action="parse editor")
async def parse_editor(self): async def parse_editor(self):
# First parse each line # First parse each line
new_lines = self.tasklist_editor.value.splitlines() new_lines = self.tasklist_editor.value.splitlines()
@@ -155,27 +157,28 @@ class BulkEditor(LeoModal):
) )
# TODO: Incremental/diff editing # TODO: Incremental/diff editing
conn = await self.bot.db.get_connection() async with self.bot.db.connection() as conn:
async with conn.transaction(): self.bot.db.conn = conn
now = utc_now() async with conn.transaction():
now = utc_now()
if same_layout: if same_layout:
# if the layout has not changed, just edit the tasks # if the layout has not changed, just edit the tasks
for taskid, (oldinfo, newinfo) in zip(self.lines.keys(), zip(old_info, taskinfo)): for taskid, (oldinfo, newinfo) in zip(self.lines.keys(), zip(old_info, taskinfo)):
args = {} args = {}
if oldinfo[2] != newinfo[2]: if oldinfo[2] != newinfo[2]:
args['completed_at'] = now if newinfo[2] else None args['completed_at'] = now if newinfo[2] else None
if oldinfo[3] != newinfo[3]: if oldinfo[3] != newinfo[3]:
args['content'] = newinfo[3] args['content'] = newinfo[3]
if args: if args:
await self.tasklist.update_tasks(taskid, **args) await self.tasklist.update_tasks(taskid, **args)
else: else:
# Naive implementation clearing entire tasklist # Naive implementation clearing entire tasklist
# Clear tasklist # Clear tasklist
await self.tasklist.update_tasklist(deleted_at=now) await self.tasklist.update_tasklist(deleted_at=now)
# Create tasklist # Create tasklist
await self.tasklist.write_taskinfo(taskinfo) await self.tasklist.write_taskinfo(taskinfo)
class UIMode(Enum): class UIMode(Enum):

View File

@@ -2,6 +2,7 @@ from typing import Type
import json import json
from data import RowModel, Table, ORDER from data import RowModel, Table, ORDER
from meta.logger import log_wrap, set_logging_context
class ModelData: class ModelData:
@@ -60,6 +61,7 @@ class ModelData:
It only updates. It only updates.
""" """
# TODO: Better way of getting the key? # TODO: Better way of getting the key?
# TODO: Transaction
if not isinstance(parent_id, tuple): if not isinstance(parent_id, tuple):
parent_id = (parent_id, ) parent_id = (parent_id, )
model = cls._model model = cls._model
@@ -83,6 +85,8 @@ class ListData:
This assumes the list is the only data stored in the table, This assumes the list is the only data stored in the table,
and removes list entries by deleting rows. and removes list entries by deleting rows.
""" """
setting_id: str
# Table storing the setting data # Table storing the setting data
_table_interface: Table _table_interface: Table
@@ -100,10 +104,12 @@ class ListData:
_cache = None # Map[id -> value] _cache = None # Map[id -> value]
@classmethod @classmethod
@log_wrap(isolate=True)
async def _reader(cls, parent_id, use_cache=True, **kwargs): async def _reader(cls, parent_id, use_cache=True, **kwargs):
""" """
Read in all entries associated to the given id. Read in all entries associated to the given id.
""" """
set_logging_context(action="Read cls.setting_id")
if cls._cache is not None and parent_id in cls._cache and use_cache: if cls._cache is not None and parent_id in cls._cache and use_cache:
return cls._cache[parent_id] return cls._cache[parent_id]
@@ -121,53 +127,56 @@ class ListData:
return data return data
@classmethod @classmethod
@log_wrap(isolate=True)
async def _writer(cls, id, data, add_only=False, remove_only=False, **kwargs): async def _writer(cls, id, data, add_only=False, remove_only=False, **kwargs):
""" """
Write the provided list to storage. Write the provided list to storage.
""" """
set_logging_context(action="Write cls.setting_id")
table = cls._table_interface table = cls._table_interface
conn = await table.connector.get_connection() async with table.connector.connection() as conn:
async with conn.transaction(): table.connector.conn = conn
# Handle None input as an empty list async with conn.transaction():
if data is None: # Handle None input as an empty list
data = [] if data is None:
data = []
current = await cls._reader(id, use_cache=False, **kwargs) current = await cls._reader(id, use_cache=False, **kwargs)
if not cls._order_column and (add_only or remove_only): if not cls._order_column and (add_only or remove_only):
to_insert = [item for item in data if item not in current] if not remove_only else [] to_insert = [item for item in data if item not in current] if not remove_only else []
to_remove = data if remove_only else ( to_remove = data if remove_only else (
[item for item in current if item not in data] if not add_only else [] [item for item in current if item not in data] if not add_only else []
) )
# Handle required deletions # Handle required deletions
if to_remove: if to_remove:
params = { params = {
cls._id_column: id, cls._id_column: id,
cls._data_column: to_remove cls._data_column: to_remove
} }
await table.delete_where(**params) await table.delete_where(**params)
# Handle required insertions # Handle required insertions
if to_insert: if to_insert:
columns = (cls._id_column, cls._data_column) columns = (cls._id_column, cls._data_column)
values = [(id, value) for value in to_insert] values = [(id, value) for value in to_insert]
await table.insert_many(columns, *values) await table.insert_many(columns, *values)
if cls._cache is not None: if cls._cache is not None:
new_current = [item for item in current + to_insert if item not in to_remove] new_current = [item for item in current + to_insert if item not in to_remove]
cls._cache[id] = new_current cls._cache[id] = new_current
else: else:
# Remove all and add all to preserve order # Remove all and add all to preserve order
delete_params = {cls._id_column: id} delete_params = {cls._id_column: id}
await table.delete_where(**delete_params) await table.delete_where(**delete_params)
if data: if data:
columns = (cls._id_column, cls._data_column) columns = (cls._id_column, cls._data_column)
values = [(id, value) for value in data] values = [(id, value) for value in data]
await table.insert_many(columns, *values) await table.insert_many(columns, *values)
if cls._cache is not None: if cls._cache is not None:
cls._cache[id] = data cls._cache[id] = data
class KeyValueData: class KeyValueData:

View File

@@ -1,7 +1,7 @@
from itertools import chain from itertools import chain
from psycopg import sql from psycopg import sql
from meta.logger import log_wrap
from data import RowModel, Registry, Table from data import RowModel, Registry, Table
from data.columns import Integer, String, Timestamp, Bool from data.columns import Integer, String, Timestamp, Bool
@@ -72,6 +72,7 @@ class TextTrackerData(Registry):
member_expid = Integer() member_expid = Integer()
@classmethod @classmethod
@log_wrap(action='end_text_sessions')
async def end_sessions(cls, connector, *session_data): async def end_sessions(cls, connector, *session_data):
query = sql.SQL(""" query = sql.SQL("""
WITH WITH
@@ -92,7 +93,7 @@ class TextTrackerData(Registry):
) SELECT ) SELECT
data._guildid, 0, data._guildid, 0,
NULL, data._userid, NULL, data._userid,
SUM(_coins), 0, 'TEXT_SESSION' LEAST(SUM(_coins :: BIGINT), 2147483647), 0, 'TEXT_SESSION'
FROM data FROM data
WHERE data._coins > 0 WHERE data._coins > 0
GROUP BY (data._guildid, data._userid) GROUP BY (data._guildid, data._userid)
@@ -100,7 +101,7 @@ class TextTrackerData(Registry):
) )
, member AS ( , member AS (
UPDATE members UPDATE members
SET coins = coins + data._coins SET coins = LEAST(coins :: BIGINT + data._coins :: BIGINT, 2147483647)
FROM data FROM data
WHERE members.userid = data._userid AND members.guildid = data._guildid WHERE members.userid = data._userid AND members.guildid = data._guildid
) )
@@ -166,15 +167,16 @@ class TextTrackerData(Registry):
# Or ask for a connection from the connection pool # Or ask for a connection from the connection pool
# Transaction may take some time due to index updates # Transaction may take some time due to index updates
# Alternatively maybe use the "do not expect response mode" # Alternatively maybe use the "do not expect response mode"
conn = await connector.get_connection() async with connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
tuple(chain(*session_data)) tuple(chain(*session_data))
) )
return return
@classmethod @classmethod
@log_wrap(action='user_messages_between')
async def user_messages_between(cls, userid: int, *points): async def user_messages_between(cls, userid: int, *points):
""" """
Compute messages written between the given points. Compute messages written between the given points.
@@ -203,15 +205,16 @@ class TextTrackerData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:] sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
tuple(chain((userid,), *blocks)) tuple(chain((userid,), *blocks))
) )
return [r['period_m'] or 0 for r in await cursor.fetchall()] return [r['period_m'] or 0 for r in await cursor.fetchall()]
@classmethod @classmethod
@log_wrap(action='member_messages_between')
async def member_messages_between(cls, guildid: int, userid: int, *points): async def member_messages_between(cls, guildid: int, userid: int, *points):
""" """
Compute messages written between the given points. Compute messages written between the given points.
@@ -241,15 +244,16 @@ class TextTrackerData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:] sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
tuple(chain((userid, guildid), *blocks)) tuple(chain((userid, guildid), *blocks))
) )
return [r['period_m'] or 0 for r in await cursor.fetchall()] return [r['period_m'] or 0 for r in await cursor.fetchall()]
@classmethod @classmethod
@log_wrap(action='member_messages_since')
async def member_messages_since(cls, guildid: int, userid: int, *points): async def member_messages_since(cls, guildid: int, userid: int, *points):
""" """
Compute messages written between the given points. Compute messages written between the given points.
@@ -277,12 +281,12 @@ class TextTrackerData(Registry):
sql.SQL("({})").format(sql.Placeholder()) for _ in points sql.SQL("({})").format(sql.Placeholder()) for _ in points
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
tuple(chain((userid, guildid), points)) tuple(chain((userid, guildid), points))
) )
return [r['messages'] or 0 for r in await cursor.fetchall()] return [r['messages'] or 0 for r in await cursor.fetchall()]
untracked_channels = Table('untracked_text_channels') untracked_channels = Table('untracked_text_channels')

View File

@@ -2,6 +2,7 @@ import datetime as dt
from itertools import chain from itertools import chain
from psycopg import sql from psycopg import sql
from meta.logger import log_wrap
from data import RowModel, Registry, Table from data import RowModel, Registry, Table
from data.columns import Integer, String, Timestamp, Bool from data.columns import Integer, String, Timestamp, Bool
@@ -113,16 +114,18 @@ class VoiceTrackerData(Registry):
hourly_coins = Integer() hourly_coins = Integer()
@classmethod @classmethod
@log_wrap(action='close_voice_session')
async def close_study_session_at(cls, guildid: int, userid: int, _at: dt.datetime) -> int: async def close_study_session_at(cls, guildid: int, userid: int, _at: dt.datetime) -> int:
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
"SELECT close_study_session_at(%s, %s, %s)", "SELECT close_study_session_at(%s, %s, %s)",
(guildid, userid, _at) (guildid, userid, _at)
) )
member_data = await cursor.fetchone() return await cursor.fetchone()
@classmethod @classmethod
@log_wrap(action='close_voice_sessions')
async def close_voice_sessions_at(cls, *arg_tuples): async def close_voice_sessions_at(cls, *arg_tuples):
query = sql.SQL(""" query = sql.SQL("""
SELECT SELECT
@@ -139,28 +142,30 @@ class VoiceTrackerData(Registry):
for _ in arg_tuples for _ in arg_tuples
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
tuple(chain(*arg_tuples)) tuple(chain(*arg_tuples))
) )
@classmethod @classmethod
@log_wrap(action='update_voice_session')
async def update_voice_session_at( async def update_voice_session_at(
cls, guildid: int, userid: int, _at: dt.datetime, cls, guildid: int, userid: int, _at: dt.datetime,
stream: bool, video: bool, rate: float stream: bool, video: bool, rate: float
) -> int: ) -> int:
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
"SELECT * FROM update_voice_session(%s, %s, %s, %s, %s, %s)", "SELECT * FROM update_voice_session(%s, %s, %s, %s, %s, %s)",
(guildid, userid, _at, stream, video, rate) (guildid, userid, _at, stream, video, rate)
) )
rows = await cursor.fetchall() rows = await cursor.fetchall()
return cls._make_rows(*rows) return cls._make_rows(*rows)
@classmethod @classmethod
@log_wrap(action='update_voice_sessions')
async def update_voice_sessions_at(cls, *arg_tuples): async def update_voice_sessions_at(cls, *arg_tuples):
query = sql.SQL(""" query = sql.SQL("""
UPDATE UPDATE
@@ -209,14 +214,14 @@ class VoiceTrackerData(Registry):
for _ in arg_tuples for _ in arg_tuples
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
tuple(chain(*arg_tuples)) tuple(chain(*arg_tuples))
) )
rows = await cursor.fetchall() rows = await cursor.fetchall()
return cls._make_rows(*rows) return cls._make_rows(*rows)
class VoiceSessions(RowModel): class VoiceSessions(RowModel):
""" """
@@ -257,17 +262,19 @@ class VoiceTrackerData(Registry):
transactionid = Integer() transactionid = Integer()
@classmethod @classmethod
@log_wrap(action='study_time_since')
async def study_time_since(cls, guildid: int, userid: int, _start) -> int: async def study_time_since(cls, guildid: int, userid: int, _start) -> int:
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
"SELECT study_time_since(%s, %s, %s) AS result", "SELECT study_time_since(%s, %s, %s) AS result",
(guildid, userid, _start) (guildid, userid, _start)
) )
result = await cursor.fetchone() result = await cursor.fetchone()
return (result['result'] or 0) if result else 0 return (result['result'] or 0) if result else 0
@classmethod @classmethod
@log_wrap(action='multiple_voice_tracked_since')
async def multiple_voice_tracked_since(cls, *arg_tuples): async def multiple_voice_tracked_since(cls, *arg_tuples):
query = sql.SQL(""" query = sql.SQL("""
SELECT SELECT
@@ -286,13 +293,13 @@ class VoiceTrackerData(Registry):
for _ in arg_tuples for _ in arg_tuples
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
tuple(chain(*arg_tuples)) tuple(chain(*arg_tuples))
) )
return await cursor.fetchall() return await cursor.fetchall()
""" """
Schema Schema

View File

@@ -161,7 +161,6 @@ class VoiceSession:
self.state.channelid, guildid=self.guildid, deleted=False self.state.channelid, guildid=self.guildid, deleted=False
) )
conn = await self.bot.db.get_connection()
# Insert an ongoing_session with the correct state, set data # Insert an ongoing_session with the correct state, set data
state = self.state state = self.state
self.data = await self.registry.VoiceSessionsOngoing.create( self.data = await self.registry.VoiceSessionsOngoing.create(