fix (data): Parallel connection pool.

This commit is contained in:
2023-08-23 17:31:38 +03:00
parent 5bca9bca33
commit df9b835cd5
27 changed files with 1175 additions and 1021 deletions

View File

@@ -52,7 +52,7 @@ class EventHandler(Generic[T]):
f"Queue on event handler {self.route_name} is full! Discarding event {data}" f"Queue on event handler {self.route_name} is full! Discarding event {data}"
) )
@log_wrap(action='consumer', isolate=False) @log_wrap(action='consumer')
async def consumer(self): async def consumer(self):
while True: while True:
try: try:
@@ -76,7 +76,7 @@ class EventHandler(Generic[T]):
) )
pass pass
@log_wrap(action='batch', isolate=False) @log_wrap(action='batch')
async def process_batch(self): async def process_batch(self):
logger.debug("Processing Batch") logger.debug("Processing Batch")
# TODO: copy syntax might be more efficient here # TODO: copy syntax might be more efficient here

View File

@@ -123,7 +123,7 @@ class AnalyticsServer:
log_action_stack.set(['Analytics']) log_action_stack.set(['Analytics'])
log_app.set(conf.analytics['appname']) log_app.set(conf.analytics['appname'])
async with await self.db.connect(): async with self.db.open():
await self.talk.connect() await self.talk.connect()
await self.attach_event_handlers() await self.attach_event_handlers()
self._snap_task = asyncio.create_task(self.snapshot_loop()) self._snap_task = asyncio.create_task(self.snapshot_loop())

View File

@@ -38,7 +38,7 @@ async def main():
intents.message_content = True intents.message_content = True
intents.presences = False intents.presences = False
async with await db.connect(): async with db.open():
version = await db.version() version = await db.version()
if version.version != DATA_VERSION: if version.version != DATA_VERSION:
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate." error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."

View File

@@ -57,8 +57,6 @@ class CoreCog(LionCog):
async def cog_load(self): async def cog_load(self):
# Fetch (and possibly create) core data rows. # Fetch (and possibly create) core data rows.
conn = await self.bot.db.get_connection()
async with conn.transaction():
self.app_config = await self.data.AppConfig.fetch_or_create(appname) self.app_config = await self.data.AppConfig.fetch_or_create(appname)
self.bot_config = await self.data.BotConfig.fetch_or_create(appname) self.bot_config = await self.data.BotConfig.fetch_or_create(appname)
self.shard_data = await self.data.Shard.fetch_or_create( self.shard_data = await self.data.Shard.fetch_or_create(

View File

@@ -5,6 +5,7 @@ from cachetools import TTLCache
import discord import discord
from meta import conf from meta import conf
from meta.logger import log_wrap
from data import Table, Registry, Column, RowModel, RegisterEnum from data import Table, Registry, Column, RowModel, RegisterEnum
from data.models import WeakCache from data.models import WeakCache
from data.columns import Integer, String, Bool, Timestamp from data.columns import Integer, String, Bool, Timestamp
@@ -287,6 +288,7 @@ class CoreData(Registry, name="core"):
_timestamp = Timestamp() _timestamp = Timestamp()
@classmethod @classmethod
@log_wrap(action="Add Pending Coins")
async def add_pending(cls, pending: list[tuple[int, int, int]]) -> list['CoreData.Member']: async def add_pending(cls, pending: list[tuple[int, int, int]]) -> list['CoreData.Member']:
""" """
Safely add pending coins to a list of members. Safely add pending coins to a list of members.
@@ -316,7 +318,7 @@ class CoreData(Registry, name="core"):
) )
) )
# TODO: Replace with copy syntax/query? # TODO: Replace with copy syntax/query?
conn = await cls.table.connector.get_connection() async with cls.table.connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
@@ -326,11 +328,12 @@ class CoreData(Registry, name="core"):
return cls._make_rows(*rows) return cls._make_rows(*rows)
@classmethod @classmethod
@log_wrap(action='get_member_rank')
async def get_member_rank(cls, guildid, userid, untracked): async def get_member_rank(cls, guildid, userid, untracked):
""" """
Get the time and coin ranking for the given member, ignoring the provided untracked members. Get the time and coin ranking for the given member, ignoring the provided untracked members.
""" """
conn = await cls.table.connector.get_connection() async with cls.table.connector.connection() as conn:
async with conn.cursor() as curs: async with conn.cursor() as curs:
await curs.execute( await curs.execute(
""" """

View File

@@ -1,6 +1,7 @@
# from enum import Enum # from enum import Enum
from typing import Optional from typing import Optional
from psycopg.types.enum import register_enum, EnumInfo from psycopg.types.enum import register_enum, EnumInfo
from psycopg import AsyncConnection
from .registry import Attachable, Registry from .registry import Attachable, Registry
@@ -23,10 +24,17 @@ class RegisterEnum(Attachable):
connector = registry._conn connector = registry._conn
if connector is None: if connector is None:
raise ValueError("Cannot initialise without connector!") raise ValueError("Cannot initialise without connector!")
connection = await connector.get_connection() connector.connect_hook(self.connection_hook)
if connection is None: # await connector.refresh_pool()
raise ValueError("Cannot Init without connection.") # The below may be somewhat dangerous
info = await EnumInfo.fetch(connection, self.name) # But adaption should never write to the database
await connector.map_over_pool(self.connection_hook)
# if conn := connector.conn:
# # Ensure the adaption is run in the current context as well
# await self.connection_hook(conn)
async def connection_hook(self, conn: AsyncConnection):
info = await EnumInfo.fetch(conn, self.name)
if info is None: if info is None:
raise ValueError(f"Enum {self.name} not found in database.") raise ValueError(f"Enum {self.name} not found in database.")
register_enum(info, connection, self.enum, mapping=list(self.mapping.items())) register_enum(info, conn, self.enum, mapping=list(self.mapping.items()))

View File

@@ -1,7 +1,10 @@
from typing import Protocol, runtime_checkable, Callable, Awaitable from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
import logging import logging
from contextvars import ContextVar
from contextlib import asynccontextmanager
import psycopg as psq import psycopg as psq
from psycopg_pool import AsyncConnectionPool
from psycopg.pq import TransactionStatus from psycopg.pq import TransactionStatus
from .cursor import AsyncLoggingCursor from .cursor import AsyncLoggingCursor
@@ -10,42 +13,110 @@ logger = logging.getLogger(__name__)
row_factory = psq.rows.dict_row row_factory = psq.rows.dict_row
ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None)
class Connector: class Connector:
cursor_factory = AsyncLoggingCursor cursor_factory = AsyncLoggingCursor
def __init__(self, conn_args): def __init__(self, conn_args):
self._conn_args = conn_args self._conn_args = conn_args
self.conn: psq.AsyncConnection = None self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
self.pool = self.make_pool()
self.conn_hooks = [] self.conn_hooks = []
async def get_connection(self) -> psq.AsyncConnection: @property
def conn(self) -> Optional[psq.AsyncConnection]:
""" """
Get the current active connection. Convenience property for the current context connection.
This should never be cached outside of a transaction.
""" """
# TODO: Reconnection logic? return ctx_connection.get()
if not self.conn:
raise ValueError("Attempting to get connection before initialisation!")
if self.conn.info.transaction_status is TransactionStatus.INERROR:
await self.connect()
logger.error(
"Database connection transaction failed!! This should not happen. Reconnecting."
)
return self.conn
async def connect(self) -> psq.AsyncConnection: @conn.setter
logger.info("Establishing connection to database.", extra={'action': "Data Connect"}) def conn(self, conn: psq.AsyncConnection):
self.conn = await psq.AsyncConnection.connect( """
self._conn_args, autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory Set the contextual connection in the current context.
Always do this in an isolated context!
"""
ctx_connection.set(conn)
def make_pool(self) -> AsyncConnectionPool:
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
return AsyncConnectionPool(
self._conn_args,
open=False,
min_size=4,
max_size=8,
configure=self._setup_connection,
kwargs=self._conn_kwargs
) )
async def refresh_pool(self):
"""
Refresh the pool.
The point of this is to invalidate any existing connections so that the connection set up is run again.
Better ways should be sought (a way to
"""
logger.info("Pool refresh requested, closing and reopening.")
old_pool = self.pool
self.pool = self.make_pool()
await self.pool.open()
logger.info(f"Old pool statistics: {self.pool.get_stats()}")
await old_pool.close()
logger.info("Pool refresh complete.")
async def map_over_pool(self, callable):
"""
Dangerous method to call a method on each connection in the pool.
Utilises private methods of the AsyncConnectionPool.
"""
async with self.pool._lock:
conns = list(self.pool._pool)
while conns:
conn = conns.pop()
try:
await callable(conn)
except Exception:
logger.exception(f"Mapped connection task failed. {callable.__name__}")
@asynccontextmanager
async def open(self):
try:
logger.info("Opening database pool.")
await self.pool.open()
yield
finally:
# May be a different pool!
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
await self.pool.close()
@asynccontextmanager
async def connection(self) -> psq.AsyncConnection:
"""
Asynchronous context manager to get and manage a connection.
If the context connection is set, uses this and does not manage the lifetime.
Otherwise, requests a new connection from the pool and returns it when done.
"""
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
if (conn := self.conn):
yield conn
else:
async with self.pool.connection() as conn:
yield conn
async def _setup_connection(self, conn: psq.AsyncConnection):
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
for hook in self.conn_hooks: for hook in self.conn_hooks:
await hook(self.conn) try:
return self.conn await hook(conn)
except Exception:
async def reconnect(self) -> psq.AsyncConnection: logger.exception("Exception encountered setting up new connection")
return await self.connect() return conn
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]): def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
""" """

View File

@@ -35,7 +35,8 @@ class Database(Connector):
""" """
Return the current schema version as a Version namedtuple. Return the current schema version as a Version namedtuple.
""" """
async with self.conn.cursor() as cursor: async with self.connection() as conn:
async with conn.cursor() as cursor:
# Get last entry in version table, compare against desired version # Get last entry in version table, compare against desired version
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1") await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
row = await cursor.fetchone() row = await cursor.fetchone()

View File

@@ -101,12 +101,12 @@ class Query(Generic[QueryResult]):
if self.connector is None: if self.connector is None:
raise ValueError("Cannot execute query without cursor, connection, or connector.") raise ValueError("Cannot execute query without cursor, connection, or connector.")
else: else:
conn = await self.connector.get_connection() async with self.connector.connection() as conn:
else:
conn = self.conn
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
data = await self._execute(cursor) data = await self._execute(cursor)
else:
async with self.conn.cursor() as cursor:
data = await self._execute(cursor)
else: else:
data = await self._execute(cursor) data = await self._execute(cursor)
return data return data

View File

@@ -1,4 +1,5 @@
from typing import Optional, Union from typing import Optional, Union
import asyncio
import discord import discord
from discord.ext import commands as cmds from discord.ext import commands as cmds
@@ -182,7 +183,9 @@ class Economy(LionCog):
# We may need to do a mass row create operation. # We may need to do a mass row create operation.
targetids = set(target.id for target in targets) targetids = set(target.id for target in targets)
if len(targets) > 1: if len(targets) > 1:
conn = await ctx.bot.db.get_connection() async def wrapper():
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction(): async with conn.transaction():
# First fetch the members which currently exist # First fetch the members which currently exist
query = self.bot.core.data.Member.table.select_where(guildid=ctx.guild.id) query = self.bot.core.data.Member.table.select_where(guildid=ctx.guild.id)
@@ -206,6 +209,8 @@ class Economy(LionCog):
('guildid', 'userid', 'coins'), ('guildid', 'userid', 'coins'),
*((ctx.guild.id, id, 0) for id in new_ids) *((ctx.guild.id, id, 0) for id in new_ids)
).on_conflict(ignore=True) ).on_conflict(ignore=True)
task = asyncio.create_task(wrapper(), name="wrapped-create-members")
await task
else: else:
# With only one target, we can take a simpler path, and make better use of local caches. # With only one target, we can take a simpler path, and make better use of local caches.
await self.bot.core.lions.fetch_member(ctx.guild.id, target.id) await self.bot.core.lions.fetch_member(ctx.guild.id, target.id)
@@ -703,7 +708,9 @@ class Economy(LionCog):
# Alternative flow could be waiting until the target user presses accept # Alternative flow could be waiting until the target user presses accept
await ctx.interaction.response.defer(thinking=True, ephemeral=True) await ctx.interaction.response.defer(thinking=True, ephemeral=True)
conn = await self.bot.db.get_connection() async def wrapped():
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction(): async with conn.transaction():
# We do this in a transaction so that if something goes wrong, # We do this in a transaction so that if something goes wrong,
# the coins deduction is rolled back atomicly # the coins deduction is rolled back atomicly
@@ -728,6 +735,7 @@ class Economy(LionCog):
await target_lion.data.update(coins=(Member.coins + amount)) await target_lion.data.update(coins=(Member.coins + amount))
# TODO: Audit trail # TODO: Audit trail
await asyncio.create_task(wrapped(), name="wrapped-send")
# Message target # Message target
embed = discord.Embed( embed = discord.Embed(

View File

@@ -1,6 +1,7 @@
from enum import Enum from enum import Enum
from psycopg import sql from psycopg import sql
from meta.logger import log_wrap
from data import Registry, RowModel, RegisterEnum, JOINTYPE, RawExpr from data import Registry, RowModel, RegisterEnum, JOINTYPE, RawExpr
from data.columns import Integer, Bool, Column, Timestamp from data.columns import Integer, Bool, Column, Timestamp
from core.data import CoreData from core.data import CoreData
@@ -101,6 +102,7 @@ class EconomyData(Registry, name='economy'):
created_at = Timestamp() created_at = Timestamp()
@classmethod @classmethod
@log_wrap(action='execute_transaction')
async def execute_transaction( async def execute_transaction(
cls, cls,
transaction_type: TransactionType, transaction_type: TransactionType,
@@ -108,7 +110,8 @@ class EconomyData(Registry, name='economy'):
from_account: int, to_account: int, amount: int, bonus: int = 0, from_account: int, to_account: int, amount: int, bonus: int = 0,
refunds: int = None refunds: int = None
): ):
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction(): async with conn.transaction():
transaction = await cls.create( transaction = await cls.create(
transactiontype=transaction_type, transactiontype=transaction_type,
@@ -127,6 +130,7 @@ class EconomyData(Registry, name='economy'):
return transaction return transaction
@classmethod @classmethod
@log_wrap(action='execute_transactions')
async def execute_transactions(cls, *transactions): async def execute_transactions(cls, *transactions):
""" """
Execute multiple transactions in one data transaction. Execute multiple transactions in one data transaction.
@@ -142,7 +146,8 @@ class EconomyData(Registry, name='economy'):
if not transactions: if not transactions:
return [] return []
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction(): async with conn.transaction():
# Create the transactions # Create the transactions
rows = await cls.table.insert_many( rows = await cls.table.insert_many(
@@ -180,10 +185,12 @@ class EconomyData(Registry, name='economy'):
return rows return rows
@classmethod @classmethod
@log_wrap(action='refund_transactions')
async def refund_transactions(cls, *transactionids, actorid=0): async def refund_transactions(cls, *transactionids, actorid=0):
if not transactionids: if not transactionids:
return [] return []
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction(): async with conn.transaction():
# First fetch the transaction rows to refund # First fetch the transaction rows to refund
data = await cls.table.select_where(transactionid=transactionids) data = await cls.table.select_where(transactionid=transactionids)
@@ -217,12 +224,14 @@ class EconomyData(Registry, name='economy'):
itemid = Integer() itemid = Integer()
@classmethod @classmethod
@log_wrap(action='purchase_transaction')
async def purchase_transaction( async def purchase_transaction(
cls, cls,
guildid: int, actorid: int, guildid: int, actorid: int,
userid: int, itemid: int, amount: int userid: int, itemid: int, amount: int
): ):
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction(): async with conn.transaction():
row = await EconomyData.Transaction.execute_transaction( row = await EconomyData.Transaction.execute_transaction(
TransactionType.SHOP_PURCHASE, TransactionType.SHOP_PURCHASE,
@@ -263,12 +272,14 @@ class EconomyData(Registry, name='economy'):
return result[0]['recent'] or 0 return result[0]['recent'] or 0
@classmethod @classmethod
@log_wrap(action='reward_completed_tasks')
async def reward_completed(cls, userid, guildid, count, amount): async def reward_completed(cls, userid, guildid, count, amount):
""" """
Reward the specified member `amount` coins for completing `count` tasks. Reward the specified member `amount` coins for completing `count` tasks.
""" """
# TODO: Bonus logic, perhaps apply_bonus(amount), or put this method in the economy cog? # TODO: Bonus logic, perhaps apply_bonus(amount), or put this method in the economy cog?
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction(): async with conn.transaction():
row = await EconomyData.Transaction.execute_transaction( row = await EconomyData.Transaction.execute_transaction(
TransactionType.TASKS, TransactionType.TASKS,

View File

@@ -193,7 +193,8 @@ class MemberAdminCog(LionCog):
await lion.data.update(last_left=utc_now()) await lion.data.update(last_left=utc_now())
# Save member roles # Save member roles
conn = await self.bot.db.get_connection() async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction(): async with conn.transaction():
await self.data.past_roles.delete_where( await self.data.past_roles.delete_where(
guildid=member.guild.id, guildid=member.guild.id,

View File

@@ -190,6 +190,8 @@ class ModerationCog(LionCog):
update_args[instance._column] = instance.data update_args[instance._column] = instance.data
ack_lines.append(instance.update_message) ack_lines.append(instance.update_message)
await ctx.lguild.data.update(**update_args)
# Do the ack # Do the ack
tick = self.bot.config.emojis.tick tick = self.bot.config.emojis.tick
embed = discord.Embed( embed = discord.Embed(

View File

@@ -483,8 +483,6 @@ class RoleMenu:
)).format(role=role.name) )).format(role=role.name)
) )
conn = await self.bot.db.get_connection()
async with conn.transaction():
# Remove the role # Remove the role
try: try:
await member.remove_roles(role) await member.remove_roles(role)
@@ -591,8 +589,6 @@ class RoleMenu:
) )
) )
conn = await self.bot.db.get_connection()
async with conn.transaction():
try: try:
await member.add_roles(role) await member.add_roles(role)
except discord.Forbidden: except discord.Forbidden:

View File

@@ -259,17 +259,6 @@ class RoomCog(LionCog):
lguild, lguild,
[member.id for member in members] [member.id for member in members]
) )
self._start(room)
# Send tips message
# TODO: Actual tips.
await channel.send(
"{mention} welcome to your private room! You may use the menu below to configure it.".format(mention=owner.mention)
)
# Send config UI
ui = RoomUI(self.bot, room, callerid=owner.id, timeout=None)
await ui.send(channel)
except Exception: except Exception:
try: try:
await channel.delete(reason="Failed to created private room") await channel.delete(reason="Failed to created private room")
@@ -454,7 +443,42 @@ class RoomCog(LionCog):
return return
# Positive response. Start a transaction. # Positive response. Start a transaction.
conn = await self.bot.db.get_connection() room = await self._do_create_room(ctx, required, days, rent, name, provided)
if room:
# Ack with confirmation message pointing to the room
msg = t(_p(
'cmd:room_rent|success',
"Successfully created your private room {channel}!"
)).format(channel=room.channel.mention)
await ctx.reply(
embed=discord.Embed(
colour=discord.Colour.brand_green(),
title=t(_p('cmd:room_rent|success|title', "Private Room Created!")),
description=msg
)
)
self._start(room)
# Send tips message
# TODO: Actual tips.
await room.channel.send(
"{mention} welcome to your private room! You may use the menu below to configure it.".format(
mention=ctx.author.mention
)
)
# Send config UI
ui = RoomUI(self.bot, room, callerid=ctx.author.id, timeout=None)
await ui.send(room.channel)
@log_wrap(action='create_room')
async def _do_create_room(self, ctx, required, days, rent, name, provided) -> Room:
t = self.bot.translator.t
# TODO: Rollback the channel create if this fails
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
# Note that the room creation will go into the UI as well.
async with conn.transaction(): async with conn.transaction():
# Check member balance is sufficient # Check member balance is sufficient
await ctx.alion.data.refresh() await ctx.alion.data.refresh()
@@ -486,7 +510,7 @@ class RoomCog(LionCog):
# Create room with given starting balance and other parameters # Create room with given starting balance and other parameters
try: try:
room = await self.create_private_room( return await self.create_private_room(
ctx.guild, ctx.guild,
ctx.author, ctx.author,
required - rent, required - rent,
@@ -519,19 +543,6 @@ class RoomCog(LionCog):
await ctx.alion.data.update(coins=CoreData.Member.coins + required) await ctx.alion.data.update(coins=CoreData.Member.coins + required)
return return
# Ack with confirmation message pointing to the room
msg = t(_p(
'cmd:room_rent|success',
"Successfully created your private room {channel}!"
)).format(channel=room.channel.mention)
await ctx.reply(
embed=discord.Embed(
colour=discord.Colour.brand_green(),
title=t(_p('cmd:room_rent|success|title', "Private Room Created!")),
description=msg
)
)
@room_group.command( @room_group.command(
name=_p('cmd:room_status', "status"), name=_p('cmd:room_status', "status"),
description=_p( description=_p(
@@ -864,8 +875,7 @@ class RoomCog(LionCog):
return return
# Start Transaction # Start Transaction
conn = await self.bot.db.get_connection() # TODO: Economy transaction
async with conn.transaction():
await ctx.alion.data.refresh() await ctx.alion.data.refresh()
member_balance = ctx.alion.data.coins member_balance = ctx.alion.data.coins
if member_balance < coins: if member_balance < coins:
@@ -883,7 +893,6 @@ class RoomCog(LionCog):
return return
# Deduct balance # Deduct balance
# TODO: Economy transaction
await ctx.alion.data.update(coins=CoreData.Member.coins - coins) await ctx.alion.data.update(coins=CoreData.Member.coins - coins)
await room.data.update(coin_balance=RoomData.Room.coin_balance + coins) await room.data.update(coin_balance=RoomData.Room.coin_balance + coins)

View File

@@ -7,6 +7,7 @@ from discord.ui.select import select, UserSelect
from meta import LionBot, conf from meta import LionBot, conf
from meta.errors import UserInputError from meta.errors import UserInputError
from meta.logger import log_wrap
from babel.translator import ctx_locale from babel.translator import ctx_locale
from utils.lib import utc_now, MessageArgs, error_embed from utils.lib import utc_now, MessageArgs, error_embed
from utils.ui import MessageUI, input from utils.ui import MessageUI, input
@@ -115,8 +116,18 @@ class RoomUI(MessageUI):
return return
await submit.response.defer(thinking=True, ephemeral=True) await submit.response.defer(thinking=True, ephemeral=True)
await self._do_deposit(t, press, amount, submit)
# Post deposit message
await self.room.notify_deposit(press.user, amount)
await self.refresh(thinking=submit)
@log_wrap(isolate=True)
async def _do_deposit(self, t, press, amount, submit):
# Start transaction for deposit # Start transaction for deposit
conn = await self.bot.db.get_connection() async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction(): async with conn.transaction():
# Get the lion balance directly # Get the lion balance directly
lion = await self.bot.core.data.Member.fetch( lion = await self.bot.core.data.Member.fetch(
@@ -143,11 +154,6 @@ class RoomUI(MessageUI):
await lion.update(coins=CoreData.Member.coins - amount) await lion.update(coins=CoreData.Member.coins - amount)
await self.room.data.update(coin_balance=RoomData.Room.coin_balance + amount) await self.room.data.update(coin_balance=RoomData.Room.coin_balance + amount)
# Post deposit message
await self.room.notify_deposit(press.user, amount)
await self.refresh(thinking=submit)
async def desposit_button_refresh(self): async def desposit_button_refresh(self):
self.desposit_button.label = self.bot.translator.t(_p( self.desposit_button.label = self.bot.translator.t(_p(
'ui:room_status|button:deposit|label', 'ui:room_status|button:deposit|label',

View File

@@ -217,7 +217,8 @@ class ScheduleCog(LionCog):
for bookingid in bookingids: for bookingid in bookingids:
await self._cancel_booking_active(*bookingid) await self._cancel_booking_active(*bookingid)
conn = await self.bot.db.get_connection() async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction(): async with conn.transaction():
# Now delete from data # Now delete from data
records = await self.data.ScheduleSessionMember.table.delete_where( records = await self.data.ScheduleSessionMember.table.delete_where(
@@ -473,7 +474,6 @@ class ScheduleCog(LionCog):
"One or more requested timeslots are already booked!" "One or more requested timeslots are already booked!"
)) ))
raise UserInputError(error) raise UserInputError(error)
conn = await self.bot.db.get_connection()
# Booking request is now validated. Perform bookings. # Booking request is now validated. Perform bookings.
# Fetch or create session data # Fetch or create session data
@@ -482,8 +482,8 @@ class ScheduleCog(LionCog):
*((guildid, slotid) for slotid in slotids) *((guildid, slotid) for slotid in slotids)
) )
async with conn.transaction():
# Create transactions # Create transactions
# TODO: wrap in a transaction so the economy transaction gets unwound if it fails
economy = self.bot.get_cog('Economy') economy = self.bot.get_cog('Economy')
trans_data = ( trans_data = (
TransactionType.SCHEDULE_BOOK, TransactionType.SCHEDULE_BOOK,

View File

@@ -356,8 +356,7 @@ class TimeSlot:
Does not modify session room channels (responsibility of the next open). Does not modify session room channels (responsibility of the next open).
""" """
try: try:
conn = await self.bot.db.get_connection() # TODO: Transaction?
async with conn.transaction():
# Calculate rewards # Calculate rewards
rewards = [] rewards = []
attendance = [] attendance = []
@@ -532,8 +531,7 @@ class TimeSlot:
This involves refunding the booking transactions, deleting the booking rows, This involves refunding the booking transactions, deleting the booking rows,
and updating any messages that may have been posted. and updating any messages that may have been posted.
""" """
conn = await self.bot.db.get_connection() # TODO: Transaction
async with conn.transaction():
# Collect booking rows # Collect booking rows
bookings = [member.data for session in sessions for member in session.members.values()] bookings = [member.data for session in sessions for member in session.members.values()]

View File

@@ -10,6 +10,7 @@ from discord.ui.button import button, Button
from meta import LionCog, LionContext, LionBot from meta import LionCog, LionContext, LionBot
from meta.errors import SafeCancellation from meta.errors import SafeCancellation
from meta.logger import log_wrap
from utils import ui from utils import ui
from utils.lib import error_embed from utils.lib import error_embed
from constants import MAX_COINS from constants import MAX_COINS
@@ -145,6 +146,7 @@ class ColourShop(Shop):
if (owned is None or item.itemid != owned.itemid) and (item.price <= balance) if (owned is None or item.itemid != owned.itemid) and (item.price <= balance)
] ]
@log_wrap(action='purchase')
async def purchase(self, itemid) -> ColourRoleItem: async def purchase(self, itemid) -> ColourRoleItem:
""" """
Atomically handle a purchase of a ColourRoleItem. Atomically handle a purchase of a ColourRoleItem.
@@ -157,7 +159,8 @@ class ColourShop(Shop):
If the purchase fails for a known reason, raises SafeCancellation, with the error information. If the purchase fails for a known reason, raises SafeCancellation, with the error information.
""" """
t = self.bot.translator.t t = self.bot.translator.t
conn = await self.bot.db.get_connection() async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction(): async with conn.transaction():
# Retrieve the item to purchase from data # Retrieve the item to purchase from data
item = await self.data.ShopItemInfo.table.select_one_where(itemid=itemid) item = await self.data.ShopItemInfo.table.select_one_where(itemid=itemid)

View File

@@ -4,6 +4,7 @@ from enum import Enum
from itertools import chain from itertools import chain
from psycopg import sql from psycopg import sql
from meta.logger import log_wrap
from data import RowModel, Registry, Table, RegisterEnum from data import RowModel, Registry, Table, RegisterEnum
from data.columns import Integer, String, Timestamp, Bool, Column from data.columns import Integer, String, Timestamp, Bool, Column
@@ -80,6 +81,7 @@ class StatsData(Registry):
end_time = Timestamp() end_time = Timestamp()
@classmethod @classmethod
@log_wrap(action='tracked_time_between')
async def tracked_time_between(cls, *points: tuple[int, int, dt.datetime, dt.datetime]): async def tracked_time_between(cls, *points: tuple[int, int, dt.datetime, dt.datetime]):
query = sql.SQL( query = sql.SQL(
""" """
@@ -103,7 +105,7 @@ class StatsData(Registry):
for _ in points for _ in points
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
@@ -112,8 +114,9 @@ class StatsData(Registry):
return cursor.fetchall() return cursor.fetchall()
@classmethod @classmethod
@log_wrap(action='study_time_between')
async def study_time_between(cls, guildid: int, userid: int, _start, _end) -> int: async def study_time_between(cls, guildid: int, userid: int, _start, _end) -> int:
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
"SELECT study_time_between(%s, %s, %s, %s)", "SELECT study_time_between(%s, %s, %s, %s)",
@@ -122,6 +125,7 @@ class StatsData(Registry):
return (await cursor.fetchone()[0]) or 0 return (await cursor.fetchone()[0]) or 0
@classmethod @classmethod
@log_wrap(action='study_times_between')
async def study_times_between(cls, guildid: int, userid: int, *points) -> list[int]: async def study_times_between(cls, guildid: int, userid: int, *points) -> list[int]:
if len(points) < 2: if len(points) < 2:
raise ValueError('Not enough block points given!') raise ValueError('Not enough block points given!')
@@ -141,7 +145,7 @@ class StatsData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:] sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
@@ -150,8 +154,9 @@ class StatsData(Registry):
return [r['stime'] or 0 for r in await cursor.fetchall()] return [r['stime'] or 0 for r in await cursor.fetchall()]
@classmethod @classmethod
@log_wrap(action='study_time_since')
async def study_time_since(cls, guildid: int, userid: int, _start) -> int: async def study_time_since(cls, guildid: int, userid: int, _start) -> int:
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
"SELECT study_time_since(%s, %s, %s)", "SELECT study_time_since(%s, %s, %s)",
@@ -160,6 +165,7 @@ class StatsData(Registry):
return (await cursor.fetchone()[0]) or 0 return (await cursor.fetchone()[0]) or 0
@classmethod @classmethod
@log_wrap(action='study_times_between')
async def study_times_since(cls, guildid: int, userid: int, *starts) -> int: async def study_times_since(cls, guildid: int, userid: int, *starts) -> int:
if len(starts) < 1: if len(starts) < 1:
raise ValueError('No starting points given!') raise ValueError('No starting points given!')
@@ -178,7 +184,7 @@ class StatsData(Registry):
sql.SQL("({})").format(sql.Placeholder()) for _ in starts sql.SQL("({})").format(sql.Placeholder()) for _ in starts
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
@@ -187,6 +193,7 @@ class StatsData(Registry):
return [r['stime'] or 0 for r in await cursor.fetchall()] return [r['stime'] or 0 for r in await cursor.fetchall()]
@classmethod @classmethod
@log_wrap(action='leaderboard_since')
async def leaderboard_since(cls, guildid: int, since): async def leaderboard_since(cls, guildid: int, since):
""" """
Return the voice totals since the given time for each member in the guild. Return the voice totals since the given time for each member in the guild.
@@ -226,7 +233,8 @@ class StatsData(Registry):
) )
second_query_args = (since, guildid, since, since) second_query_args = (since, guildid, since, since)
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction(): async with conn.transaction():
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute(second_query, second_query_args) await cursor.execute(second_query, second_query_args)
@@ -243,6 +251,7 @@ class StatsData(Registry):
return leaderboard return leaderboard
@classmethod @classmethod
@log_wrap('leaderboard_all')
async def leaderboard_all(cls, guildid: int): async def leaderboard_all(cls, guildid: int):
""" """
Return the all-time voice totals for the given guild. Return the all-time voice totals for the given guild.
@@ -257,7 +266,7 @@ class StatsData(Registry):
""" """
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute(query, (guildid, )) await cursor.execute(query, (guildid, ))
leaderboard = [ leaderboard = [
@@ -296,6 +305,7 @@ class StatsData(Registry):
transactionid = Integer() transactionid = Integer()
@classmethod @classmethod
@log_wrap(action='xp_since')
async def xp_since(cls, guildid: int, userid: int, *starts): async def xp_since(cls, guildid: int, userid: int, *starts):
query = sql.SQL( query = sql.SQL(
""" """
@@ -320,7 +330,7 @@ class StatsData(Registry):
sql.Placeholder() for _ in starts sql.Placeholder() for _ in starts
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
@@ -329,6 +339,7 @@ class StatsData(Registry):
return [r['exp'] or 0 for r in await cursor.fetchall()] return [r['exp'] or 0 for r in await cursor.fetchall()]
@classmethod @classmethod
@log_wrap(action='xp_between')
async def xp_between(cls, guildid: int, userid: int, *points): async def xp_between(cls, guildid: int, userid: int, *points):
blocks = zip(points, points[1:]) blocks = zip(points, points[1:])
query = sql.SQL( query = sql.SQL(
@@ -355,7 +366,7 @@ class StatsData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:] sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
@@ -364,6 +375,7 @@ class StatsData(Registry):
return [r['period_xp'] or 0 for r in await cursor.fetchall()] return [r['period_xp'] or 0 for r in await cursor.fetchall()]
@classmethod @classmethod
@log_wrap(action='leaderboard_since')
async def leaderboard_since(cls, guildid: int, since): async def leaderboard_since(cls, guildid: int, since):
""" """
Return the XP totals for the given guild since the given time. Return the XP totals for the given guild since the given time.
@@ -378,7 +390,7 @@ class StatsData(Registry):
""" """
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute(query, (guildid, since)) await cursor.execute(query, (guildid, since))
leaderboard = [ leaderboard = [
@@ -388,6 +400,7 @@ class StatsData(Registry):
return leaderboard return leaderboard
@classmethod @classmethod
@log_wrap(action='leaderboard_all')
async def leaderboard_all(cls, guildid: int): async def leaderboard_all(cls, guildid: int):
""" """
Return the all-time XP totals for the given guild. Return the all-time XP totals for the given guild.
@@ -402,7 +415,7 @@ class StatsData(Registry):
""" """
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute(query, (guildid, )) await cursor.execute(query, (guildid, ))
leaderboard = [ leaderboard = [
@@ -436,6 +449,7 @@ class StatsData(Registry):
exp_type: Column[ExpType] = Column() exp_type: Column[ExpType] = Column()
@classmethod @classmethod
@log_wrap(action='user_xp_since')
async def xp_since(cls, userid: int, *starts): async def xp_since(cls, userid: int, *starts):
query = sql.SQL( query = sql.SQL(
""" """
@@ -459,7 +473,7 @@ class StatsData(Registry):
sql.Placeholder() for _ in starts sql.Placeholder() for _ in starts
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
@@ -468,6 +482,7 @@ class StatsData(Registry):
return [r['exp'] or 0 for r in await cursor.fetchall()] return [r['exp'] or 0 for r in await cursor.fetchall()]
@classmethod @classmethod
@log_wrap(action='user_xp_since')
async def xp_between(cls, userid: int, *points): async def xp_between(cls, userid: int, *points):
blocks = zip(points, points[1:]) blocks = zip(points, points[1:])
query = sql.SQL( query = sql.SQL(
@@ -493,7 +508,7 @@ class StatsData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:] sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
@@ -531,8 +546,10 @@ class StatsData(Registry):
return [tag.tag for tag in tags] return [tag.tag for tag in tags]
@classmethod @classmethod
@log_wrap(action='set_profile_tags')
async def set_tags(self, guildid: Optional[int], userid: int, tags: Iterable[str]): async def set_tags(self, guildid: Optional[int], userid: int, tags: Iterable[str]):
conn = await self._connector.get_connection() async with self._connector.connection() as conn:
self._connector.conn = conn
async with conn.transaction(): async with conn.transaction():
await self.table.delete_where(guildid=guildid, userid=userid) await self.table.delete_where(guildid=guildid, userid=userid)
if tags: if tags:

View File

@@ -473,14 +473,14 @@ class WeeklyMonthlyUI(StatsUI):
# Update the tasklist # Update the tasklist
if len(new_tasks) != len(tasks) or not all(t == new_t for (t, new_t) in zip(tasks, new_tasks)): if len(new_tasks) != len(tasks) or not all(t == new_t for (t, new_t) in zip(tasks, new_tasks)):
modified = True modified = True
conn = await self.bot.db.get_connection() async with self._connector.connection() as conn:
async with conn.transaction(): async with conn.transaction():
await tasks_model.table.delete_where(**key) await tasks_model.table.delete_where(**key).with_connection(conn)
if new_tasks: if new_tasks:
await tasks_model.table.insert_many( await tasks_model.table.insert_many(
(*key.keys(), 'completed', 'content'), (*key.keys(), 'completed', 'content'),
*((*key.values(), *new_task) for new_task in new_tasks) *((*key.values(), *new_task) for new_task in new_tasks)
) ).with_connection(conn)
if modified: if modified:
# If either goal type was modified, clear the rendered cache and refresh # If either goal type was modified, clear the rendered cache and refresh

View File

@@ -8,6 +8,7 @@ from discord import app_commands as appcmds
from discord.app_commands.transformers import AppCommandOptionType as cmdopt from discord.app_commands.transformers import AppCommandOptionType as cmdopt
from meta import LionBot, LionCog, LionContext from meta import LionBot, LionCog, LionContext
from meta.logger import log_wrap
from meta.errors import UserInputError from meta.errors import UserInputError
from utils.lib import utc_now, error_embed from utils.lib import utc_now, error_embed
from utils.ui import ChoicedEnum, Transformed, AButton from utils.ui import ChoicedEnum, Transformed, AButton
@@ -141,8 +142,10 @@ class TasklistCog(LionCog):
self.crossload_group(self.configure_group, configcog.configure_group) self.crossload_group(self.configure_group, configcog.configure_group)
@LionCog.listener('on_tasks_completed') @LionCog.listener('on_tasks_completed')
@log_wrap(action="reward tasks completed")
async def reward_tasks_completed(self, member: discord.Member, *taskids: int): async def reward_tasks_completed(self, member: discord.Member, *taskids: int):
conn = await self.bot.db.get_connection() async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction(): async with conn.transaction():
tasklist = await Tasklist.fetch(self.bot, self.data, member.id) tasklist = await Tasklist.fetch(self.bot, self.data, member.id)
tasks = await tasklist.fetch_tasks(*taskids) tasks = await tasklist.fetch_tasks(*taskids)
@@ -477,11 +480,8 @@ class TasklistCog(LionCog):
# Contents successfully parsed, update the tasklist. # Contents successfully parsed, update the tasklist.
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id) tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
# Lazily using the editor because it has a good parser
taskinfo = tasklist.parse_tasklist(lines) taskinfo = tasklist.parse_tasklist(lines)
conn = await self.bot.db.get_connection()
async with conn.transaction():
now = utc_now() now = utc_now()
# Delete tasklist if required # Delete tasklist if required

View File

@@ -11,6 +11,7 @@ from discord.ui.button import button, Button, ButtonStyle
from discord.ui.text_input import TextInput, TextStyle from discord.ui.text_input import TextInput, TextStyle
from meta import conf from meta import conf
from meta.logger import log_wrap
from meta.errors import UserInputError from meta.errors import UserInputError
from utils.lib import MessageArgs, utc_now from utils.lib import MessageArgs, utc_now
from utils.ui import LeoUI, LeoModal, FastModal, error_handler_for, ModalRetryUI from utils.ui import LeoUI, LeoModal, FastModal, error_handler_for, ModalRetryUI
@@ -143,6 +144,7 @@ class BulkEditor(LeoModal):
except UserInputError as error: except UserInputError as error:
await ModalRetryUI(self, error.msg).respond_to(interaction) await ModalRetryUI(self, error.msg).respond_to(interaction)
@log_wrap(action="parse editor")
async def parse_editor(self): async def parse_editor(self):
# First parse each line # First parse each line
new_lines = self.tasklist_editor.value.splitlines() new_lines = self.tasklist_editor.value.splitlines()
@@ -155,7 +157,8 @@ class BulkEditor(LeoModal):
) )
# TODO: Incremental/diff editing # TODO: Incremental/diff editing
conn = await self.bot.db.get_connection() async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction(): async with conn.transaction():
now = utc_now() now = utc_now()

View File

@@ -2,6 +2,7 @@ from typing import Type
import json import json
from data import RowModel, Table, ORDER from data import RowModel, Table, ORDER
from meta.logger import log_wrap, set_logging_context
class ModelData: class ModelData:
@@ -60,6 +61,7 @@ class ModelData:
It only updates. It only updates.
""" """
# TODO: Better way of getting the key? # TODO: Better way of getting the key?
# TODO: Transaction
if not isinstance(parent_id, tuple): if not isinstance(parent_id, tuple):
parent_id = (parent_id, ) parent_id = (parent_id, )
model = cls._model model = cls._model
@@ -83,6 +85,8 @@ class ListData:
This assumes the list is the only data stored in the table, This assumes the list is the only data stored in the table,
and removes list entries by deleting rows. and removes list entries by deleting rows.
""" """
setting_id: str
# Table storing the setting data # Table storing the setting data
_table_interface: Table _table_interface: Table
@@ -100,10 +104,12 @@ class ListData:
_cache = None # Map[id -> value] _cache = None # Map[id -> value]
@classmethod @classmethod
@log_wrap(isolate=True)
async def _reader(cls, parent_id, use_cache=True, **kwargs): async def _reader(cls, parent_id, use_cache=True, **kwargs):
""" """
Read in all entries associated to the given id. Read in all entries associated to the given id.
""" """
set_logging_context(action="Read cls.setting_id")
if cls._cache is not None and parent_id in cls._cache and use_cache: if cls._cache is not None and parent_id in cls._cache and use_cache:
return cls._cache[parent_id] return cls._cache[parent_id]
@@ -121,12 +127,15 @@ class ListData:
return data return data
@classmethod @classmethod
@log_wrap(isolate=True)
async def _writer(cls, id, data, add_only=False, remove_only=False, **kwargs): async def _writer(cls, id, data, add_only=False, remove_only=False, **kwargs):
""" """
Write the provided list to storage. Write the provided list to storage.
""" """
set_logging_context(action="Write cls.setting_id")
table = cls._table_interface table = cls._table_interface
conn = await table.connector.get_connection() async with table.connector.connection() as conn:
table.connector.conn = conn
async with conn.transaction(): async with conn.transaction():
# Handle None input as an empty list # Handle None input as an empty list
if data is None: if data is None:

View File

@@ -1,7 +1,7 @@
from itertools import chain from itertools import chain
from psycopg import sql from psycopg import sql
from meta.logger import log_wrap
from data import RowModel, Registry, Table from data import RowModel, Registry, Table
from data.columns import Integer, String, Timestamp, Bool from data.columns import Integer, String, Timestamp, Bool
@@ -72,6 +72,7 @@ class TextTrackerData(Registry):
member_expid = Integer() member_expid = Integer()
@classmethod @classmethod
@log_wrap(action='end_text_sessions')
async def end_sessions(cls, connector, *session_data): async def end_sessions(cls, connector, *session_data):
query = sql.SQL(""" query = sql.SQL("""
WITH WITH
@@ -92,7 +93,7 @@ class TextTrackerData(Registry):
) SELECT ) SELECT
data._guildid, 0, data._guildid, 0,
NULL, data._userid, NULL, data._userid,
SUM(_coins), 0, 'TEXT_SESSION' LEAST(SUM(_coins :: BIGINT), 2147483647), 0, 'TEXT_SESSION'
FROM data FROM data
WHERE data._coins > 0 WHERE data._coins > 0
GROUP BY (data._guildid, data._userid) GROUP BY (data._guildid, data._userid)
@@ -100,7 +101,7 @@ class TextTrackerData(Registry):
) )
, member AS ( , member AS (
UPDATE members UPDATE members
SET coins = coins + data._coins SET coins = LEAST(coins :: BIGINT + data._coins :: BIGINT, 2147483647)
FROM data FROM data
WHERE members.userid = data._userid AND members.guildid = data._guildid WHERE members.userid = data._userid AND members.guildid = data._guildid
) )
@@ -166,7 +167,7 @@ class TextTrackerData(Registry):
# Or ask for a connection from the connection pool # Or ask for a connection from the connection pool
# Transaction may take some time due to index updates # Transaction may take some time due to index updates
# Alternatively maybe use the "do not expect response mode" # Alternatively maybe use the "do not expect response mode"
conn = await connector.get_connection() async with connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
@@ -175,6 +176,7 @@ class TextTrackerData(Registry):
return return
@classmethod @classmethod
@log_wrap(action='user_messages_between')
async def user_messages_between(cls, userid: int, *points): async def user_messages_between(cls, userid: int, *points):
""" """
Compute messages written between the given points. Compute messages written between the given points.
@@ -203,7 +205,7 @@ class TextTrackerData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:] sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
@@ -212,6 +214,7 @@ class TextTrackerData(Registry):
return [r['period_m'] or 0 for r in await cursor.fetchall()] return [r['period_m'] or 0 for r in await cursor.fetchall()]
@classmethod @classmethod
@log_wrap(action='member_messages_between')
async def member_messages_between(cls, guildid: int, userid: int, *points): async def member_messages_between(cls, guildid: int, userid: int, *points):
""" """
Compute messages written between the given points. Compute messages written between the given points.
@@ -241,7 +244,7 @@ class TextTrackerData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:] sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
@@ -250,6 +253,7 @@ class TextTrackerData(Registry):
return [r['period_m'] or 0 for r in await cursor.fetchall()] return [r['period_m'] or 0 for r in await cursor.fetchall()]
@classmethod @classmethod
@log_wrap(action='member_messages_since')
async def member_messages_since(cls, guildid: int, userid: int, *points): async def member_messages_since(cls, guildid: int, userid: int, *points):
""" """
Compute messages written between the given points. Compute messages written between the given points.
@@ -277,7 +281,7 @@ class TextTrackerData(Registry):
sql.SQL("({})").format(sql.Placeholder()) for _ in points sql.SQL("({})").format(sql.Placeholder()) for _ in points
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,

View File

@@ -2,6 +2,7 @@ import datetime as dt
from itertools import chain from itertools import chain
from psycopg import sql from psycopg import sql
from meta.logger import log_wrap
from data import RowModel, Registry, Table from data import RowModel, Registry, Table
from data.columns import Integer, String, Timestamp, Bool from data.columns import Integer, String, Timestamp, Bool
@@ -113,16 +114,18 @@ class VoiceTrackerData(Registry):
hourly_coins = Integer() hourly_coins = Integer()
@classmethod @classmethod
@log_wrap(action='close_voice_session')
async def close_study_session_at(cls, guildid: int, userid: int, _at: dt.datetime) -> int: async def close_study_session_at(cls, guildid: int, userid: int, _at: dt.datetime) -> int:
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
"SELECT close_study_session_at(%s, %s, %s)", "SELECT close_study_session_at(%s, %s, %s)",
(guildid, userid, _at) (guildid, userid, _at)
) )
member_data = await cursor.fetchone() return await cursor.fetchone()
@classmethod @classmethod
@log_wrap(action='close_voice_sessions')
async def close_voice_sessions_at(cls, *arg_tuples): async def close_voice_sessions_at(cls, *arg_tuples):
query = sql.SQL(""" query = sql.SQL("""
SELECT SELECT
@@ -139,7 +142,7 @@ class VoiceTrackerData(Registry):
for _ in arg_tuples for _ in arg_tuples
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
@@ -147,11 +150,12 @@ class VoiceTrackerData(Registry):
) )
@classmethod @classmethod
@log_wrap(action='update_voice_session')
async def update_voice_session_at( async def update_voice_session_at(
cls, guildid: int, userid: int, _at: dt.datetime, cls, guildid: int, userid: int, _at: dt.datetime,
stream: bool, video: bool, rate: float stream: bool, video: bool, rate: float
) -> int: ) -> int:
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
"SELECT * FROM update_voice_session(%s, %s, %s, %s, %s, %s)", "SELECT * FROM update_voice_session(%s, %s, %s, %s, %s, %s)",
@@ -161,6 +165,7 @@ class VoiceTrackerData(Registry):
return cls._make_rows(*rows) return cls._make_rows(*rows)
@classmethod @classmethod
@log_wrap(action='update_voice_sessions')
async def update_voice_sessions_at(cls, *arg_tuples): async def update_voice_sessions_at(cls, *arg_tuples):
query = sql.SQL(""" query = sql.SQL("""
UPDATE UPDATE
@@ -209,7 +214,7 @@ class VoiceTrackerData(Registry):
for _ in arg_tuples for _ in arg_tuples
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,
@@ -257,8 +262,9 @@ class VoiceTrackerData(Registry):
transactionid = Integer() transactionid = Integer()
@classmethod @classmethod
@log_wrap(action='study_time_since')
async def study_time_since(cls, guildid: int, userid: int, _start) -> int: async def study_time_since(cls, guildid: int, userid: int, _start) -> int:
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
"SELECT study_time_since(%s, %s, %s) AS result", "SELECT study_time_since(%s, %s, %s) AS result",
@@ -268,6 +274,7 @@ class VoiceTrackerData(Registry):
return (result['result'] or 0) if result else 0 return (result['result'] or 0) if result else 0
@classmethod @classmethod
@log_wrap(action='multiple_voice_tracked_since')
async def multiple_voice_tracked_since(cls, *arg_tuples): async def multiple_voice_tracked_since(cls, *arg_tuples):
query = sql.SQL(""" query = sql.SQL("""
SELECT SELECT
@@ -286,7 +293,7 @@ class VoiceTrackerData(Registry):
for _ in arg_tuples for _ in arg_tuples
) )
) )
conn = await cls._connector.get_connection() async with cls._connector.connection() as conn:
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
await cursor.execute( await cursor.execute(
query, query,

View File

@@ -161,7 +161,6 @@ class VoiceSession:
self.state.channelid, guildid=self.guildid, deleted=False self.state.channelid, guildid=self.guildid, deleted=False
) )
conn = await self.bot.db.get_connection()
# Insert an ongoing_session with the correct state, set data # Insert an ongoing_session with the correct state, set data
state = self.state state = self.state
self.data = await self.registry.VoiceSessionsOngoing.create( self.data = await self.registry.VoiceSessionsOngoing.create(