fix (data): Parallel connection pool.

This commit is contained in:
2023-08-23 17:31:38 +03:00
parent 5bca9bca33
commit df9b835cd5
27 changed files with 1175 additions and 1021 deletions

View File

@@ -52,7 +52,7 @@ class EventHandler(Generic[T]):
f"Queue on event handler {self.route_name} is full! Discarding event {data}"
)
@log_wrap(action='consumer', isolate=False)
@log_wrap(action='consumer')
async def consumer(self):
while True:
try:
@@ -76,7 +76,7 @@ class EventHandler(Generic[T]):
)
pass
@log_wrap(action='batch', isolate=False)
@log_wrap(action='batch')
async def process_batch(self):
logger.debug("Processing Batch")
# TODO: copy syntax might be more efficient here

View File

@@ -123,7 +123,7 @@ class AnalyticsServer:
log_action_stack.set(['Analytics'])
log_app.set(conf.analytics['appname'])
async with await self.db.connect():
async with self.db.open():
await self.talk.connect()
await self.attach_event_handlers()
self._snap_task = asyncio.create_task(self.snapshot_loop())

View File

@@ -38,7 +38,7 @@ async def main():
intents.message_content = True
intents.presences = False
async with await db.connect():
async with db.open():
version = await db.version()
if version.version != DATA_VERSION:
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."

View File

@@ -57,8 +57,6 @@ class CoreCog(LionCog):
async def cog_load(self):
# Fetch (and possibly create) core data rows.
conn = await self.bot.db.get_connection()
async with conn.transaction():
self.app_config = await self.data.AppConfig.fetch_or_create(appname)
self.bot_config = await self.data.BotConfig.fetch_or_create(appname)
self.shard_data = await self.data.Shard.fetch_or_create(

View File

@@ -5,6 +5,7 @@ from cachetools import TTLCache
import discord
from meta import conf
from meta.logger import log_wrap
from data import Table, Registry, Column, RowModel, RegisterEnum
from data.models import WeakCache
from data.columns import Integer, String, Bool, Timestamp
@@ -287,6 +288,7 @@ class CoreData(Registry, name="core"):
_timestamp = Timestamp()
@classmethod
@log_wrap(action="Add Pending Coins")
async def add_pending(cls, pending: list[tuple[int, int, int]]) -> list['CoreData.Member']:
"""
Safely add pending coins to a list of members.
@@ -316,7 +318,7 @@ class CoreData(Registry, name="core"):
)
)
# TODO: Replace with copy syntax/query?
conn = await cls.table.connector.get_connection()
async with cls.table.connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
@@ -326,11 +328,12 @@ class CoreData(Registry, name="core"):
return cls._make_rows(*rows)
@classmethod
@log_wrap(action='get_member_rank')
async def get_member_rank(cls, guildid, userid, untracked):
"""
Get the time and coin ranking for the given member, ignoring the provided untracked members.
"""
conn = await cls.table.connector.get_connection()
async with cls.table.connector.connection() as conn:
async with conn.cursor() as curs:
await curs.execute(
"""

View File

@@ -1,6 +1,7 @@
# from enum import Enum
from typing import Optional
from psycopg.types.enum import register_enum, EnumInfo
from psycopg import AsyncConnection
from .registry import Attachable, Registry
@@ -23,10 +24,17 @@ class RegisterEnum(Attachable):
connector = registry._conn
if connector is None:
raise ValueError("Cannot initialise without connector!")
connection = await connector.get_connection()
if connection is None:
raise ValueError("Cannot Init without connection.")
info = await EnumInfo.fetch(connection, self.name)
connector.connect_hook(self.connection_hook)
# await connector.refresh_pool()
# The below may be somewhat dangerous
# But adaption should never write to the database
await connector.map_over_pool(self.connection_hook)
# if conn := connector.conn:
# # Ensure the adaption is run in the current context as well
# await self.connection_hook(conn)
async def connection_hook(self, conn: AsyncConnection):
info = await EnumInfo.fetch(conn, self.name)
if info is None:
raise ValueError(f"Enum {self.name} not found in database.")
register_enum(info, connection, self.enum, mapping=list(self.mapping.items()))
register_enum(info, conn, self.enum, mapping=list(self.mapping.items()))

View File

@@ -1,7 +1,10 @@
from typing import Protocol, runtime_checkable, Callable, Awaitable
from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
import logging
from contextvars import ContextVar
from contextlib import asynccontextmanager
import psycopg as psq
from psycopg_pool import AsyncConnectionPool
from psycopg.pq import TransactionStatus
from .cursor import AsyncLoggingCursor
@@ -10,42 +13,110 @@ logger = logging.getLogger(__name__)
row_factory = psq.rows.dict_row
ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None)
class Connector:
cursor_factory = AsyncLoggingCursor
def __init__(self, conn_args):
self._conn_args = conn_args
self.conn: psq.AsyncConnection = None
self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
self.pool = self.make_pool()
self.conn_hooks = []
async def get_connection(self) -> psq.AsyncConnection:
@property
def conn(self) -> Optional[psq.AsyncConnection]:
"""
Get the current active connection.
This should never be cached outside of a transaction.
Convenience property for the current context connection.
"""
# TODO: Reconnection logic?
if not self.conn:
raise ValueError("Attempting to get connection before initialisation!")
if self.conn.info.transaction_status is TransactionStatus.INERROR:
await self.connect()
logger.error(
"Database connection transaction failed!! This should not happen. Reconnecting."
)
return self.conn
return ctx_connection.get()
async def connect(self) -> psq.AsyncConnection:
logger.info("Establishing connection to database.", extra={'action': "Data Connect"})
self.conn = await psq.AsyncConnection.connect(
self._conn_args, autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory
@conn.setter
def conn(self, conn: psq.AsyncConnection):
"""
Set the contextual connection in the current context.
Always do this in an isolated context!
"""
ctx_connection.set(conn)
def make_pool(self) -> AsyncConnectionPool:
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
return AsyncConnectionPool(
self._conn_args,
open=False,
min_size=4,
max_size=8,
configure=self._setup_connection,
kwargs=self._conn_kwargs
)
async def refresh_pool(self):
"""
Refresh the pool.
The point of this is to invalidate any existing connections so that the connection set up is run again.
Better ways should be sought (a way to
"""
logger.info("Pool refresh requested, closing and reopening.")
old_pool = self.pool
self.pool = self.make_pool()
await self.pool.open()
logger.info(f"Old pool statistics: {self.pool.get_stats()}")
await old_pool.close()
logger.info("Pool refresh complete.")
async def map_over_pool(self, callable):
"""
Dangerous method to call a method on each connection in the pool.
Utilises private methods of the AsyncConnectionPool.
"""
async with self.pool._lock:
conns = list(self.pool._pool)
while conns:
conn = conns.pop()
try:
await callable(conn)
except Exception:
logger.exception(f"Mapped connection task failed. {callable.__name__}")
@asynccontextmanager
async def open(self):
try:
logger.info("Opening database pool.")
await self.pool.open()
yield
finally:
# May be a different pool!
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
await self.pool.close()
@asynccontextmanager
async def connection(self) -> psq.AsyncConnection:
"""
Asynchronous context manager to get and manage a connection.
If the context connection is set, uses this and does not manage the lifetime.
Otherwise, requests a new connection from the pool and returns it when done.
"""
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
if (conn := self.conn):
yield conn
else:
async with self.pool.connection() as conn:
yield conn
async def _setup_connection(self, conn: psq.AsyncConnection):
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
for hook in self.conn_hooks:
await hook(self.conn)
return self.conn
async def reconnect(self) -> psq.AsyncConnection:
return await self.connect()
try:
await hook(conn)
except Exception:
logger.exception("Exception encountered setting up new connection")
return conn
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
"""

View File

@@ -35,7 +35,8 @@ class Database(Connector):
"""
Return the current schema version as a Version namedtuple.
"""
async with self.conn.cursor() as cursor:
async with self.connection() as conn:
async with conn.cursor() as cursor:
# Get last entry in version table, compare against desired version
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
row = await cursor.fetchone()

View File

@@ -101,12 +101,12 @@ class Query(Generic[QueryResult]):
if self.connector is None:
raise ValueError("Cannot execute query without cursor, connection, or connector.")
else:
conn = await self.connector.get_connection()
else:
conn = self.conn
async with self.connector.connection() as conn:
async with conn.cursor() as cursor:
data = await self._execute(cursor)
else:
async with self.conn.cursor() as cursor:
data = await self._execute(cursor)
else:
data = await self._execute(cursor)
return data

View File

@@ -1,4 +1,5 @@
from typing import Optional, Union
import asyncio
import discord
from discord.ext import commands as cmds
@@ -182,7 +183,9 @@ class Economy(LionCog):
# We may need to do a mass row create operation.
targetids = set(target.id for target in targets)
if len(targets) > 1:
conn = await ctx.bot.db.get_connection()
async def wrapper():
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
# First fetch the members which currently exist
query = self.bot.core.data.Member.table.select_where(guildid=ctx.guild.id)
@@ -206,6 +209,8 @@ class Economy(LionCog):
('guildid', 'userid', 'coins'),
*((ctx.guild.id, id, 0) for id in new_ids)
).on_conflict(ignore=True)
task = asyncio.create_task(wrapper(), name="wrapped-create-members")
await task
else:
# With only one target, we can take a simpler path, and make better use of local caches.
await self.bot.core.lions.fetch_member(ctx.guild.id, target.id)
@@ -703,7 +708,9 @@ class Economy(LionCog):
# Alternative flow could be waiting until the target user presses accept
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
conn = await self.bot.db.get_connection()
async def wrapped():
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
# We do this in a transaction so that if something goes wrong,
# the coins deduction is rolled back atomicly
@@ -728,6 +735,7 @@ class Economy(LionCog):
await target_lion.data.update(coins=(Member.coins + amount))
# TODO: Audit trail
await asyncio.create_task(wrapped(), name="wrapped-send")
# Message target
embed = discord.Embed(

View File

@@ -1,6 +1,7 @@
from enum import Enum
from psycopg import sql
from meta.logger import log_wrap
from data import Registry, RowModel, RegisterEnum, JOINTYPE, RawExpr
from data.columns import Integer, Bool, Column, Timestamp
from core.data import CoreData
@@ -101,6 +102,7 @@ class EconomyData(Registry, name='economy'):
created_at = Timestamp()
@classmethod
@log_wrap(action='execute_transaction')
async def execute_transaction(
cls,
transaction_type: TransactionType,
@@ -108,7 +110,8 @@ class EconomyData(Registry, name='economy'):
from_account: int, to_account: int, amount: int, bonus: int = 0,
refunds: int = None
):
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction():
transaction = await cls.create(
transactiontype=transaction_type,
@@ -127,6 +130,7 @@ class EconomyData(Registry, name='economy'):
return transaction
@classmethod
@log_wrap(action='execute_transactions')
async def execute_transactions(cls, *transactions):
"""
Execute multiple transactions in one data transaction.
@@ -142,7 +146,8 @@ class EconomyData(Registry, name='economy'):
if not transactions:
return []
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction():
# Create the transactions
rows = await cls.table.insert_many(
@@ -180,10 +185,12 @@ class EconomyData(Registry, name='economy'):
return rows
@classmethod
@log_wrap(action='refund_transactions')
async def refund_transactions(cls, *transactionids, actorid=0):
if not transactionids:
return []
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction():
# First fetch the transaction rows to refund
data = await cls.table.select_where(transactionid=transactionids)
@@ -217,12 +224,14 @@ class EconomyData(Registry, name='economy'):
itemid = Integer()
@classmethod
@log_wrap(action='purchase_transaction')
async def purchase_transaction(
cls,
guildid: int, actorid: int,
userid: int, itemid: int, amount: int
):
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction():
row = await EconomyData.Transaction.execute_transaction(
TransactionType.SHOP_PURCHASE,
@@ -263,12 +272,14 @@ class EconomyData(Registry, name='economy'):
return result[0]['recent'] or 0
@classmethod
@log_wrap(action='reward_completed_tasks')
async def reward_completed(cls, userid, guildid, count, amount):
"""
Reward the specified member `amount` coins for completing `count` tasks.
"""
# TODO: Bonus logic, perhaps apply_bonus(amount), or put this method in the economy cog?
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction():
row = await EconomyData.Transaction.execute_transaction(
TransactionType.TASKS,

View File

@@ -193,7 +193,8 @@ class MemberAdminCog(LionCog):
await lion.data.update(last_left=utc_now())
# Save member roles
conn = await self.bot.db.get_connection()
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
await self.data.past_roles.delete_where(
guildid=member.guild.id,

View File

@@ -190,6 +190,8 @@ class ModerationCog(LionCog):
update_args[instance._column] = instance.data
ack_lines.append(instance.update_message)
await ctx.lguild.data.update(**update_args)
# Do the ack
tick = self.bot.config.emojis.tick
embed = discord.Embed(

View File

@@ -483,8 +483,6 @@ class RoleMenu:
)).format(role=role.name)
)
conn = await self.bot.db.get_connection()
async with conn.transaction():
# Remove the role
try:
await member.remove_roles(role)
@@ -591,8 +589,6 @@ class RoleMenu:
)
)
conn = await self.bot.db.get_connection()
async with conn.transaction():
try:
await member.add_roles(role)
except discord.Forbidden:

View File

@@ -259,17 +259,6 @@ class RoomCog(LionCog):
lguild,
[member.id for member in members]
)
self._start(room)
# Send tips message
# TODO: Actual tips.
await channel.send(
"{mention} welcome to your private room! You may use the menu below to configure it.".format(mention=owner.mention)
)
# Send config UI
ui = RoomUI(self.bot, room, callerid=owner.id, timeout=None)
await ui.send(channel)
except Exception:
try:
await channel.delete(reason="Failed to created private room")
@@ -454,7 +443,42 @@ class RoomCog(LionCog):
return
# Positive response. Start a transaction.
conn = await self.bot.db.get_connection()
room = await self._do_create_room(ctx, required, days, rent, name, provided)
if room:
# Ack with confirmation message pointing to the room
msg = t(_p(
'cmd:room_rent|success',
"Successfully created your private room {channel}!"
)).format(channel=room.channel.mention)
await ctx.reply(
embed=discord.Embed(
colour=discord.Colour.brand_green(),
title=t(_p('cmd:room_rent|success|title', "Private Room Created!")),
description=msg
)
)
self._start(room)
# Send tips message
# TODO: Actual tips.
await room.channel.send(
"{mention} welcome to your private room! You may use the menu below to configure it.".format(
mention=ctx.author.mention
)
)
# Send config UI
ui = RoomUI(self.bot, room, callerid=ctx.author.id, timeout=None)
await ui.send(room.channel)
@log_wrap(action='create_room')
async def _do_create_room(self, ctx, required, days, rent, name, provided) -> Room:
t = self.bot.translator.t
# TODO: Rollback the channel create if this fails
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
# Note that the room creation will go into the UI as well.
async with conn.transaction():
# Check member balance is sufficient
await ctx.alion.data.refresh()
@@ -486,7 +510,7 @@ class RoomCog(LionCog):
# Create room with given starting balance and other parameters
try:
room = await self.create_private_room(
return await self.create_private_room(
ctx.guild,
ctx.author,
required - rent,
@@ -519,19 +543,6 @@ class RoomCog(LionCog):
await ctx.alion.data.update(coins=CoreData.Member.coins + required)
return
# Ack with confirmation message pointing to the room
msg = t(_p(
'cmd:room_rent|success',
"Successfully created your private room {channel}!"
)).format(channel=room.channel.mention)
await ctx.reply(
embed=discord.Embed(
colour=discord.Colour.brand_green(),
title=t(_p('cmd:room_rent|success|title', "Private Room Created!")),
description=msg
)
)
@room_group.command(
name=_p('cmd:room_status', "status"),
description=_p(
@@ -864,8 +875,7 @@ class RoomCog(LionCog):
return
# Start Transaction
conn = await self.bot.db.get_connection()
async with conn.transaction():
# TODO: Economy transaction
await ctx.alion.data.refresh()
member_balance = ctx.alion.data.coins
if member_balance < coins:
@@ -883,7 +893,6 @@ class RoomCog(LionCog):
return
# Deduct balance
# TODO: Economy transaction
await ctx.alion.data.update(coins=CoreData.Member.coins - coins)
await room.data.update(coin_balance=RoomData.Room.coin_balance + coins)

View File

@@ -7,6 +7,7 @@ from discord.ui.select import select, UserSelect
from meta import LionBot, conf
from meta.errors import UserInputError
from meta.logger import log_wrap
from babel.translator import ctx_locale
from utils.lib import utc_now, MessageArgs, error_embed
from utils.ui import MessageUI, input
@@ -115,8 +116,18 @@ class RoomUI(MessageUI):
return
await submit.response.defer(thinking=True, ephemeral=True)
await self._do_deposit(t, press, amount, submit)
# Post deposit message
await self.room.notify_deposit(press.user, amount)
await self.refresh(thinking=submit)
@log_wrap(isolate=True)
async def _do_deposit(self, t, press, amount, submit):
# Start transaction for deposit
conn = await self.bot.db.get_connection()
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
# Get the lion balance directly
lion = await self.bot.core.data.Member.fetch(
@@ -143,11 +154,6 @@ class RoomUI(MessageUI):
await lion.update(coins=CoreData.Member.coins - amount)
await self.room.data.update(coin_balance=RoomData.Room.coin_balance + amount)
# Post deposit message
await self.room.notify_deposit(press.user, amount)
await self.refresh(thinking=submit)
async def desposit_button_refresh(self):
self.desposit_button.label = self.bot.translator.t(_p(
'ui:room_status|button:deposit|label',

View File

@@ -217,7 +217,8 @@ class ScheduleCog(LionCog):
for bookingid in bookingids:
await self._cancel_booking_active(*bookingid)
conn = await self.bot.db.get_connection()
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
# Now delete from data
records = await self.data.ScheduleSessionMember.table.delete_where(
@@ -473,7 +474,6 @@ class ScheduleCog(LionCog):
"One or more requested timeslots are already booked!"
))
raise UserInputError(error)
conn = await self.bot.db.get_connection()
# Booking request is now validated. Perform bookings.
# Fetch or create session data
@@ -482,8 +482,8 @@ class ScheduleCog(LionCog):
*((guildid, slotid) for slotid in slotids)
)
async with conn.transaction():
# Create transactions
# TODO: wrap in a transaction so the economy transaction gets unwound if it fails
economy = self.bot.get_cog('Economy')
trans_data = (
TransactionType.SCHEDULE_BOOK,

View File

@@ -356,8 +356,7 @@ class TimeSlot:
Does not modify session room channels (responsibility of the next open).
"""
try:
conn = await self.bot.db.get_connection()
async with conn.transaction():
# TODO: Transaction?
# Calculate rewards
rewards = []
attendance = []
@@ -532,8 +531,7 @@ class TimeSlot:
This involves refunding the booking transactions, deleting the booking rows,
and updating any messages that may have been posted.
"""
conn = await self.bot.db.get_connection()
async with conn.transaction():
# TODO: Transaction
# Collect booking rows
bookings = [member.data for session in sessions for member in session.members.values()]

View File

@@ -10,6 +10,7 @@ from discord.ui.button import button, Button
from meta import LionCog, LionContext, LionBot
from meta.errors import SafeCancellation
from meta.logger import log_wrap
from utils import ui
from utils.lib import error_embed
from constants import MAX_COINS
@@ -145,6 +146,7 @@ class ColourShop(Shop):
if (owned is None or item.itemid != owned.itemid) and (item.price <= balance)
]
@log_wrap(action='purchase')
async def purchase(self, itemid) -> ColourRoleItem:
"""
Atomically handle a purchase of a ColourRoleItem.
@@ -157,7 +159,8 @@ class ColourShop(Shop):
If the purchase fails for a known reason, raises SafeCancellation, with the error information.
"""
t = self.bot.translator.t
conn = await self.bot.db.get_connection()
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
# Retrieve the item to purchase from data
item = await self.data.ShopItemInfo.table.select_one_where(itemid=itemid)

View File

@@ -4,6 +4,7 @@ from enum import Enum
from itertools import chain
from psycopg import sql
from meta.logger import log_wrap
from data import RowModel, Registry, Table, RegisterEnum
from data.columns import Integer, String, Timestamp, Bool, Column
@@ -80,6 +81,7 @@ class StatsData(Registry):
end_time = Timestamp()
@classmethod
@log_wrap(action='tracked_time_between')
async def tracked_time_between(cls, *points: tuple[int, int, dt.datetime, dt.datetime]):
query = sql.SQL(
"""
@@ -103,7 +105,7 @@ class StatsData(Registry):
for _ in points
)
)
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
@@ -112,8 +114,9 @@ class StatsData(Registry):
return cursor.fetchall()
@classmethod
@log_wrap(action='study_time_between')
async def study_time_between(cls, guildid: int, userid: int, _start, _end) -> int:
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
"SELECT study_time_between(%s, %s, %s, %s)",
@@ -122,6 +125,7 @@ class StatsData(Registry):
return (await cursor.fetchone()[0]) or 0
@classmethod
@log_wrap(action='study_times_between')
async def study_times_between(cls, guildid: int, userid: int, *points) -> list[int]:
if len(points) < 2:
raise ValueError('Not enough block points given!')
@@ -141,7 +145,7 @@ class StatsData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
)
)
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
@@ -150,8 +154,9 @@ class StatsData(Registry):
return [r['stime'] or 0 for r in await cursor.fetchall()]
@classmethod
@log_wrap(action='study_time_since')
async def study_time_since(cls, guildid: int, userid: int, _start) -> int:
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
"SELECT study_time_since(%s, %s, %s)",
@@ -160,6 +165,7 @@ class StatsData(Registry):
return (await cursor.fetchone()[0]) or 0
@classmethod
@log_wrap(action='study_times_between')
async def study_times_since(cls, guildid: int, userid: int, *starts) -> int:
if len(starts) < 1:
raise ValueError('No starting points given!')
@@ -178,7 +184,7 @@ class StatsData(Registry):
sql.SQL("({})").format(sql.Placeholder()) for _ in starts
)
)
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
@@ -187,6 +193,7 @@ class StatsData(Registry):
return [r['stime'] or 0 for r in await cursor.fetchall()]
@classmethod
@log_wrap(action='leaderboard_since')
async def leaderboard_since(cls, guildid: int, since):
"""
Return the voice totals since the given time for each member in the guild.
@@ -226,7 +233,8 @@ class StatsData(Registry):
)
second_query_args = (since, guildid, since, since)
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction():
async with conn.cursor() as cursor:
await cursor.execute(second_query, second_query_args)
@@ -243,6 +251,7 @@ class StatsData(Registry):
return leaderboard
@classmethod
@log_wrap('leaderboard_all')
async def leaderboard_all(cls, guildid: int):
"""
Return the all-time voice totals for the given guild.
@@ -257,7 +266,7 @@ class StatsData(Registry):
"""
)
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(query, (guildid, ))
leaderboard = [
@@ -296,6 +305,7 @@ class StatsData(Registry):
transactionid = Integer()
@classmethod
@log_wrap(action='xp_since')
async def xp_since(cls, guildid: int, userid: int, *starts):
query = sql.SQL(
"""
@@ -320,7 +330,7 @@ class StatsData(Registry):
sql.Placeholder() for _ in starts
)
)
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
@@ -329,6 +339,7 @@ class StatsData(Registry):
return [r['exp'] or 0 for r in await cursor.fetchall()]
@classmethod
@log_wrap(action='xp_between')
async def xp_between(cls, guildid: int, userid: int, *points):
blocks = zip(points, points[1:])
query = sql.SQL(
@@ -355,7 +366,7 @@ class StatsData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
)
)
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
@@ -364,6 +375,7 @@ class StatsData(Registry):
return [r['period_xp'] or 0 for r in await cursor.fetchall()]
@classmethod
@log_wrap(action='leaderboard_since')
async def leaderboard_since(cls, guildid: int, since):
"""
Return the XP totals for the given guild since the given time.
@@ -378,7 +390,7 @@ class StatsData(Registry):
"""
)
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(query, (guildid, since))
leaderboard = [
@@ -388,6 +400,7 @@ class StatsData(Registry):
return leaderboard
@classmethod
@log_wrap(action='leaderboard_all')
async def leaderboard_all(cls, guildid: int):
"""
Return the all-time XP totals for the given guild.
@@ -402,7 +415,7 @@ class StatsData(Registry):
"""
)
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(query, (guildid, ))
leaderboard = [
@@ -436,6 +449,7 @@ class StatsData(Registry):
exp_type: Column[ExpType] = Column()
@classmethod
@log_wrap(action='user_xp_since')
async def xp_since(cls, userid: int, *starts):
query = sql.SQL(
"""
@@ -459,7 +473,7 @@ class StatsData(Registry):
sql.Placeholder() for _ in starts
)
)
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
@@ -468,6 +482,7 @@ class StatsData(Registry):
return [r['exp'] or 0 for r in await cursor.fetchall()]
@classmethod
@log_wrap(action='user_xp_since')
async def xp_between(cls, userid: int, *points):
blocks = zip(points, points[1:])
query = sql.SQL(
@@ -493,7 +508,7 @@ class StatsData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
)
)
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
@@ -531,8 +546,10 @@ class StatsData(Registry):
return [tag.tag for tag in tags]
@classmethod
@log_wrap(action='set_profile_tags')
async def set_tags(self, guildid: Optional[int], userid: int, tags: Iterable[str]):
conn = await self._connector.get_connection()
async with self._connector.connection() as conn:
self._connector.conn = conn
async with conn.transaction():
await self.table.delete_where(guildid=guildid, userid=userid)
if tags:

View File

@@ -473,14 +473,14 @@ class WeeklyMonthlyUI(StatsUI):
# Update the tasklist
if len(new_tasks) != len(tasks) or not all(t == new_t for (t, new_t) in zip(tasks, new_tasks)):
modified = True
conn = await self.bot.db.get_connection()
async with self._connector.connection() as conn:
async with conn.transaction():
await tasks_model.table.delete_where(**key)
await tasks_model.table.delete_where(**key).with_connection(conn)
if new_tasks:
await tasks_model.table.insert_many(
(*key.keys(), 'completed', 'content'),
*((*key.values(), *new_task) for new_task in new_tasks)
)
).with_connection(conn)
if modified:
# If either goal type was modified, clear the rendered cache and refresh

View File

@@ -8,6 +8,7 @@ from discord import app_commands as appcmds
from discord.app_commands.transformers import AppCommandOptionType as cmdopt
from meta import LionBot, LionCog, LionContext
from meta.logger import log_wrap
from meta.errors import UserInputError
from utils.lib import utc_now, error_embed
from utils.ui import ChoicedEnum, Transformed, AButton
@@ -141,8 +142,10 @@ class TasklistCog(LionCog):
self.crossload_group(self.configure_group, configcog.configure_group)
@LionCog.listener('on_tasks_completed')
@log_wrap(action="reward tasks completed")
async def reward_tasks_completed(self, member: discord.Member, *taskids: int):
conn = await self.bot.db.get_connection()
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
tasklist = await Tasklist.fetch(self.bot, self.data, member.id)
tasks = await tasklist.fetch_tasks(*taskids)
@@ -477,11 +480,8 @@ class TasklistCog(LionCog):
# Contents successfully parsed, update the tasklist.
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
# Lazily using the editor because it has a good parser
taskinfo = tasklist.parse_tasklist(lines)
conn = await self.bot.db.get_connection()
async with conn.transaction():
now = utc_now()
# Delete tasklist if required

View File

@@ -11,6 +11,7 @@ from discord.ui.button import button, Button, ButtonStyle
from discord.ui.text_input import TextInput, TextStyle
from meta import conf
from meta.logger import log_wrap
from meta.errors import UserInputError
from utils.lib import MessageArgs, utc_now
from utils.ui import LeoUI, LeoModal, FastModal, error_handler_for, ModalRetryUI
@@ -143,6 +144,7 @@ class BulkEditor(LeoModal):
except UserInputError as error:
await ModalRetryUI(self, error.msg).respond_to(interaction)
@log_wrap(action="parse editor")
async def parse_editor(self):
# First parse each line
new_lines = self.tasklist_editor.value.splitlines()
@@ -155,7 +157,8 @@ class BulkEditor(LeoModal):
)
# TODO: Incremental/diff editing
conn = await self.bot.db.get_connection()
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
now = utc_now()

View File

@@ -2,6 +2,7 @@ from typing import Type
import json
from data import RowModel, Table, ORDER
from meta.logger import log_wrap, set_logging_context
class ModelData:
@@ -60,6 +61,7 @@ class ModelData:
It only updates.
"""
# TODO: Better way of getting the key?
# TODO: Transaction
if not isinstance(parent_id, tuple):
parent_id = (parent_id, )
model = cls._model
@@ -83,6 +85,8 @@ class ListData:
This assumes the list is the only data stored in the table,
and removes list entries by deleting rows.
"""
setting_id: str
# Table storing the setting data
_table_interface: Table
@@ -100,10 +104,12 @@ class ListData:
_cache = None # Map[id -> value]
@classmethod
@log_wrap(isolate=True)
async def _reader(cls, parent_id, use_cache=True, **kwargs):
"""
Read in all entries associated to the given id.
"""
set_logging_context(action="Read cls.setting_id")
if cls._cache is not None and parent_id in cls._cache and use_cache:
return cls._cache[parent_id]
@@ -121,12 +127,15 @@ class ListData:
return data
@classmethod
@log_wrap(isolate=True)
async def _writer(cls, id, data, add_only=False, remove_only=False, **kwargs):
"""
Write the provided list to storage.
"""
set_logging_context(action="Write cls.setting_id")
table = cls._table_interface
conn = await table.connector.get_connection()
async with table.connector.connection() as conn:
table.connector.conn = conn
async with conn.transaction():
# Handle None input as an empty list
if data is None:

View File

@@ -1,7 +1,7 @@
from itertools import chain
from psycopg import sql
from meta.logger import log_wrap
from data import RowModel, Registry, Table
from data.columns import Integer, String, Timestamp, Bool
@@ -72,6 +72,7 @@ class TextTrackerData(Registry):
member_expid = Integer()
@classmethod
@log_wrap(action='end_text_sessions')
async def end_sessions(cls, connector, *session_data):
query = sql.SQL("""
WITH
@@ -92,7 +93,7 @@ class TextTrackerData(Registry):
) SELECT
data._guildid, 0,
NULL, data._userid,
SUM(_coins), 0, 'TEXT_SESSION'
LEAST(SUM(_coins :: BIGINT), 2147483647), 0, 'TEXT_SESSION'
FROM data
WHERE data._coins > 0
GROUP BY (data._guildid, data._userid)
@@ -100,7 +101,7 @@ class TextTrackerData(Registry):
)
, member AS (
UPDATE members
SET coins = coins + data._coins
SET coins = LEAST(coins :: BIGINT + data._coins :: BIGINT, 2147483647)
FROM data
WHERE members.userid = data._userid AND members.guildid = data._guildid
)
@@ -166,7 +167,7 @@ class TextTrackerData(Registry):
# Or ask for a connection from the connection pool
# Transaction may take some time due to index updates
# Alternatively maybe use the "do not expect response mode"
conn = await connector.get_connection()
async with connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
@@ -175,6 +176,7 @@ class TextTrackerData(Registry):
return
@classmethod
@log_wrap(action='user_messages_between')
async def user_messages_between(cls, userid: int, *points):
"""
Compute messages written between the given points.
@@ -203,7 +205,7 @@ class TextTrackerData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
)
)
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
@@ -212,6 +214,7 @@ class TextTrackerData(Registry):
return [r['period_m'] or 0 for r in await cursor.fetchall()]
@classmethod
@log_wrap(action='member_messages_between')
async def member_messages_between(cls, guildid: int, userid: int, *points):
"""
Compute messages written between the given points.
@@ -241,7 +244,7 @@ class TextTrackerData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
)
)
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
@@ -250,6 +253,7 @@ class TextTrackerData(Registry):
return [r['period_m'] or 0 for r in await cursor.fetchall()]
@classmethod
@log_wrap(action='member_messages_since')
async def member_messages_since(cls, guildid: int, userid: int, *points):
"""
Compute messages written between the given points.
@@ -277,7 +281,7 @@ class TextTrackerData(Registry):
sql.SQL("({})").format(sql.Placeholder()) for _ in points
)
)
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,

View File

@@ -2,6 +2,7 @@ import datetime as dt
from itertools import chain
from psycopg import sql
from meta.logger import log_wrap
from data import RowModel, Registry, Table
from data.columns import Integer, String, Timestamp, Bool
@@ -113,16 +114,18 @@ class VoiceTrackerData(Registry):
hourly_coins = Integer()
@classmethod
@log_wrap(action='close_voice_session')
async def close_study_session_at(cls, guildid: int, userid: int, _at: dt.datetime) -> int:
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
"SELECT close_study_session_at(%s, %s, %s)",
(guildid, userid, _at)
)
member_data = await cursor.fetchone()
return await cursor.fetchone()
@classmethod
@log_wrap(action='close_voice_sessions')
async def close_voice_sessions_at(cls, *arg_tuples):
query = sql.SQL("""
SELECT
@@ -139,7 +142,7 @@ class VoiceTrackerData(Registry):
for _ in arg_tuples
)
)
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
@@ -147,11 +150,12 @@ class VoiceTrackerData(Registry):
)
@classmethod
@log_wrap(action='update_voice_session')
async def update_voice_session_at(
cls, guildid: int, userid: int, _at: dt.datetime,
stream: bool, video: bool, rate: float
) -> int:
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
"SELECT * FROM update_voice_session(%s, %s, %s, %s, %s, %s)",
@@ -161,6 +165,7 @@ class VoiceTrackerData(Registry):
return cls._make_rows(*rows)
@classmethod
@log_wrap(action='update_voice_sessions')
async def update_voice_sessions_at(cls, *arg_tuples):
query = sql.SQL("""
UPDATE
@@ -209,7 +214,7 @@ class VoiceTrackerData(Registry):
for _ in arg_tuples
)
)
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
@@ -257,8 +262,9 @@ class VoiceTrackerData(Registry):
transactionid = Integer()
@classmethod
@log_wrap(action='study_time_since')
async def study_time_since(cls, guildid: int, userid: int, _start) -> int:
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
"SELECT study_time_since(%s, %s, %s) AS result",
@@ -268,6 +274,7 @@ class VoiceTrackerData(Registry):
return (result['result'] or 0) if result else 0
@classmethod
@log_wrap(action='multiple_voice_tracked_since')
async def multiple_voice_tracked_since(cls, *arg_tuples):
query = sql.SQL("""
SELECT
@@ -286,7 +293,7 @@ class VoiceTrackerData(Registry):
for _ in arg_tuples
)
)
conn = await cls._connector.get_connection()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,

View File

@@ -161,7 +161,6 @@ class VoiceSession:
self.state.channelid, guildid=self.guildid, deleted=False
)
conn = await self.bot.db.get_connection()
# Insert an ongoing_session with the correct state, set data
state = self.state
self.data = await self.registry.VoiceSessionsOngoing.create(