fix (data): Parallel connection pool.
This commit is contained in:
@@ -52,7 +52,7 @@ class EventHandler(Generic[T]):
|
||||
f"Queue on event handler {self.route_name} is full! Discarding event {data}"
|
||||
)
|
||||
|
||||
@log_wrap(action='consumer', isolate=False)
|
||||
@log_wrap(action='consumer')
|
||||
async def consumer(self):
|
||||
while True:
|
||||
try:
|
||||
@@ -76,7 +76,7 @@ class EventHandler(Generic[T]):
|
||||
)
|
||||
pass
|
||||
|
||||
@log_wrap(action='batch', isolate=False)
|
||||
@log_wrap(action='batch')
|
||||
async def process_batch(self):
|
||||
logger.debug("Processing Batch")
|
||||
# TODO: copy syntax might be more efficient here
|
||||
|
||||
@@ -123,7 +123,7 @@ class AnalyticsServer:
|
||||
log_action_stack.set(['Analytics'])
|
||||
log_app.set(conf.analytics['appname'])
|
||||
|
||||
async with await self.db.connect():
|
||||
async with self.db.open():
|
||||
await self.talk.connect()
|
||||
await self.attach_event_handlers()
|
||||
self._snap_task = asyncio.create_task(self.snapshot_loop())
|
||||
|
||||
@@ -38,7 +38,7 @@ async def main():
|
||||
intents.message_content = True
|
||||
intents.presences = False
|
||||
|
||||
async with await db.connect():
|
||||
async with db.open():
|
||||
version = await db.version()
|
||||
if version.version != DATA_VERSION:
|
||||
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."
|
||||
|
||||
@@ -57,8 +57,6 @@ class CoreCog(LionCog):
|
||||
|
||||
async def cog_load(self):
|
||||
# Fetch (and possibly create) core data rows.
|
||||
conn = await self.bot.db.get_connection()
|
||||
async with conn.transaction():
|
||||
self.app_config = await self.data.AppConfig.fetch_or_create(appname)
|
||||
self.bot_config = await self.data.BotConfig.fetch_or_create(appname)
|
||||
self.shard_data = await self.data.Shard.fetch_or_create(
|
||||
|
||||
@@ -5,6 +5,7 @@ from cachetools import TTLCache
|
||||
import discord
|
||||
|
||||
from meta import conf
|
||||
from meta.logger import log_wrap
|
||||
from data import Table, Registry, Column, RowModel, RegisterEnum
|
||||
from data.models import WeakCache
|
||||
from data.columns import Integer, String, Bool, Timestamp
|
||||
@@ -287,6 +288,7 @@ class CoreData(Registry, name="core"):
|
||||
_timestamp = Timestamp()
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action="Add Pending Coins")
|
||||
async def add_pending(cls, pending: list[tuple[int, int, int]]) -> list['CoreData.Member']:
|
||||
"""
|
||||
Safely add pending coins to a list of members.
|
||||
@@ -316,7 +318,7 @@ class CoreData(Registry, name="core"):
|
||||
)
|
||||
)
|
||||
# TODO: Replace with copy syntax/query?
|
||||
conn = await cls.table.connector.get_connection()
|
||||
async with cls.table.connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
@@ -326,11 +328,12 @@ class CoreData(Registry, name="core"):
|
||||
return cls._make_rows(*rows)
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='get_member_rank')
|
||||
async def get_member_rank(cls, guildid, userid, untracked):
|
||||
"""
|
||||
Get the time and coin ranking for the given member, ignoring the provided untracked members.
|
||||
"""
|
||||
conn = await cls.table.connector.get_connection()
|
||||
async with cls.table.connector.connection() as conn:
|
||||
async with conn.cursor() as curs:
|
||||
await curs.execute(
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# from enum import Enum
|
||||
from typing import Optional
|
||||
from psycopg.types.enum import register_enum, EnumInfo
|
||||
from psycopg import AsyncConnection
|
||||
from .registry import Attachable, Registry
|
||||
|
||||
|
||||
@@ -23,10 +24,17 @@ class RegisterEnum(Attachable):
|
||||
connector = registry._conn
|
||||
if connector is None:
|
||||
raise ValueError("Cannot initialise without connector!")
|
||||
connection = await connector.get_connection()
|
||||
if connection is None:
|
||||
raise ValueError("Cannot Init without connection.")
|
||||
info = await EnumInfo.fetch(connection, self.name)
|
||||
connector.connect_hook(self.connection_hook)
|
||||
# await connector.refresh_pool()
|
||||
# The below may be somewhat dangerous
|
||||
# But adaption should never write to the database
|
||||
await connector.map_over_pool(self.connection_hook)
|
||||
# if conn := connector.conn:
|
||||
# # Ensure the adaption is run in the current context as well
|
||||
# await self.connection_hook(conn)
|
||||
|
||||
async def connection_hook(self, conn: AsyncConnection):
|
||||
info = await EnumInfo.fetch(conn, self.name)
|
||||
if info is None:
|
||||
raise ValueError(f"Enum {self.name} not found in database.")
|
||||
register_enum(info, connection, self.enum, mapping=list(self.mapping.items()))
|
||||
register_enum(info, conn, self.enum, mapping=list(self.mapping.items()))
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from typing import Protocol, runtime_checkable, Callable, Awaitable
|
||||
from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
|
||||
import logging
|
||||
|
||||
from contextvars import ContextVar
|
||||
from contextlib import asynccontextmanager
|
||||
import psycopg as psq
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
from psycopg.pq import TransactionStatus
|
||||
|
||||
from .cursor import AsyncLoggingCursor
|
||||
@@ -10,42 +13,110 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
row_factory = psq.rows.dict_row
|
||||
|
||||
ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None)
|
||||
|
||||
|
||||
class Connector:
|
||||
cursor_factory = AsyncLoggingCursor
|
||||
|
||||
def __init__(self, conn_args):
|
||||
self._conn_args = conn_args
|
||||
self.conn: psq.AsyncConnection = None
|
||||
self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
|
||||
|
||||
self.pool = self.make_pool()
|
||||
|
||||
self.conn_hooks = []
|
||||
|
||||
async def get_connection(self) -> psq.AsyncConnection:
|
||||
@property
|
||||
def conn(self) -> Optional[psq.AsyncConnection]:
|
||||
"""
|
||||
Get the current active connection.
|
||||
This should never be cached outside of a transaction.
|
||||
Convenience property for the current context connection.
|
||||
"""
|
||||
# TODO: Reconnection logic?
|
||||
if not self.conn:
|
||||
raise ValueError("Attempting to get connection before initialisation!")
|
||||
if self.conn.info.transaction_status is TransactionStatus.INERROR:
|
||||
await self.connect()
|
||||
logger.error(
|
||||
"Database connection transaction failed!! This should not happen. Reconnecting."
|
||||
)
|
||||
return self.conn
|
||||
return ctx_connection.get()
|
||||
|
||||
async def connect(self) -> psq.AsyncConnection:
|
||||
logger.info("Establishing connection to database.", extra={'action': "Data Connect"})
|
||||
self.conn = await psq.AsyncConnection.connect(
|
||||
self._conn_args, autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory
|
||||
@conn.setter
|
||||
def conn(self, conn: psq.AsyncConnection):
|
||||
"""
|
||||
Set the contextual connection in the current context.
|
||||
Always do this in an isolated context!
|
||||
"""
|
||||
ctx_connection.set(conn)
|
||||
|
||||
def make_pool(self) -> AsyncConnectionPool:
|
||||
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
|
||||
return AsyncConnectionPool(
|
||||
self._conn_args,
|
||||
open=False,
|
||||
min_size=4,
|
||||
max_size=8,
|
||||
configure=self._setup_connection,
|
||||
kwargs=self._conn_kwargs
|
||||
)
|
||||
|
||||
async def refresh_pool(self):
|
||||
"""
|
||||
Refresh the pool.
|
||||
|
||||
The point of this is to invalidate any existing connections so that the connection set up is run again.
|
||||
Better ways should be sought (a way to
|
||||
"""
|
||||
logger.info("Pool refresh requested, closing and reopening.")
|
||||
old_pool = self.pool
|
||||
self.pool = self.make_pool()
|
||||
await self.pool.open()
|
||||
logger.info(f"Old pool statistics: {self.pool.get_stats()}")
|
||||
await old_pool.close()
|
||||
logger.info("Pool refresh complete.")
|
||||
|
||||
async def map_over_pool(self, callable):
|
||||
"""
|
||||
Dangerous method to call a method on each connection in the pool.
|
||||
|
||||
Utilises private methods of the AsyncConnectionPool.
|
||||
"""
|
||||
async with self.pool._lock:
|
||||
conns = list(self.pool._pool)
|
||||
while conns:
|
||||
conn = conns.pop()
|
||||
try:
|
||||
await callable(conn)
|
||||
except Exception:
|
||||
logger.exception(f"Mapped connection task failed. {callable.__name__}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def open(self):
|
||||
try:
|
||||
logger.info("Opening database pool.")
|
||||
await self.pool.open()
|
||||
yield
|
||||
finally:
|
||||
# May be a different pool!
|
||||
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
|
||||
await self.pool.close()
|
||||
|
||||
@asynccontextmanager
|
||||
async def connection(self) -> psq.AsyncConnection:
|
||||
"""
|
||||
Asynchronous context manager to get and manage a connection.
|
||||
|
||||
If the context connection is set, uses this and does not manage the lifetime.
|
||||
Otherwise, requests a new connection from the pool and returns it when done.
|
||||
"""
|
||||
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
|
||||
if (conn := self.conn):
|
||||
yield conn
|
||||
else:
|
||||
async with self.pool.connection() as conn:
|
||||
yield conn
|
||||
|
||||
async def _setup_connection(self, conn: psq.AsyncConnection):
|
||||
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
|
||||
for hook in self.conn_hooks:
|
||||
await hook(self.conn)
|
||||
return self.conn
|
||||
|
||||
async def reconnect(self) -> psq.AsyncConnection:
|
||||
return await self.connect()
|
||||
try:
|
||||
await hook(conn)
|
||||
except Exception:
|
||||
logger.exception("Exception encountered setting up new connection")
|
||||
return conn
|
||||
|
||||
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
|
||||
"""
|
||||
|
||||
@@ -35,7 +35,8 @@ class Database(Connector):
|
||||
"""
|
||||
Return the current schema version as a Version namedtuple.
|
||||
"""
|
||||
async with self.conn.cursor() as cursor:
|
||||
async with self.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
# Get last entry in version table, compare against desired version
|
||||
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
|
||||
row = await cursor.fetchone()
|
||||
|
||||
@@ -101,12 +101,12 @@ class Query(Generic[QueryResult]):
|
||||
if self.connector is None:
|
||||
raise ValueError("Cannot execute query without cursor, connection, or connector.")
|
||||
else:
|
||||
conn = await self.connector.get_connection()
|
||||
else:
|
||||
conn = self.conn
|
||||
|
||||
async with self.connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
data = await self._execute(cursor)
|
||||
else:
|
||||
async with self.conn.cursor() as cursor:
|
||||
data = await self._execute(cursor)
|
||||
else:
|
||||
data = await self._execute(cursor)
|
||||
return data
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Optional, Union
|
||||
import asyncio
|
||||
|
||||
import discord
|
||||
from discord.ext import commands as cmds
|
||||
@@ -182,7 +183,9 @@ class Economy(LionCog):
|
||||
# We may need to do a mass row create operation.
|
||||
targetids = set(target.id for target in targets)
|
||||
if len(targets) > 1:
|
||||
conn = await ctx.bot.db.get_connection()
|
||||
async def wrapper():
|
||||
async with self.bot.db.connection() as conn:
|
||||
self.bot.db.conn = conn
|
||||
async with conn.transaction():
|
||||
# First fetch the members which currently exist
|
||||
query = self.bot.core.data.Member.table.select_where(guildid=ctx.guild.id)
|
||||
@@ -206,6 +209,8 @@ class Economy(LionCog):
|
||||
('guildid', 'userid', 'coins'),
|
||||
*((ctx.guild.id, id, 0) for id in new_ids)
|
||||
).on_conflict(ignore=True)
|
||||
task = asyncio.create_task(wrapper(), name="wrapped-create-members")
|
||||
await task
|
||||
else:
|
||||
# With only one target, we can take a simpler path, and make better use of local caches.
|
||||
await self.bot.core.lions.fetch_member(ctx.guild.id, target.id)
|
||||
@@ -703,7 +708,9 @@ class Economy(LionCog):
|
||||
# Alternative flow could be waiting until the target user presses accept
|
||||
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
|
||||
|
||||
conn = await self.bot.db.get_connection()
|
||||
async def wrapped():
|
||||
async with self.bot.db.connection() as conn:
|
||||
self.bot.db.conn = conn
|
||||
async with conn.transaction():
|
||||
# We do this in a transaction so that if something goes wrong,
|
||||
# the coins deduction is rolled back atomicly
|
||||
@@ -728,6 +735,7 @@ class Economy(LionCog):
|
||||
await target_lion.data.update(coins=(Member.coins + amount))
|
||||
|
||||
# TODO: Audit trail
|
||||
await asyncio.create_task(wrapped(), name="wrapped-send")
|
||||
|
||||
# Message target
|
||||
embed = discord.Embed(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from enum import Enum
|
||||
|
||||
from psycopg import sql
|
||||
from meta.logger import log_wrap
|
||||
from data import Registry, RowModel, RegisterEnum, JOINTYPE, RawExpr
|
||||
from data.columns import Integer, Bool, Column, Timestamp
|
||||
from core.data import CoreData
|
||||
@@ -101,6 +102,7 @@ class EconomyData(Registry, name='economy'):
|
||||
created_at = Timestamp()
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='execute_transaction')
|
||||
async def execute_transaction(
|
||||
cls,
|
||||
transaction_type: TransactionType,
|
||||
@@ -108,7 +110,8 @@ class EconomyData(Registry, name='economy'):
|
||||
from_account: int, to_account: int, amount: int, bonus: int = 0,
|
||||
refunds: int = None
|
||||
):
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
cls._connector.conn = conn
|
||||
async with conn.transaction():
|
||||
transaction = await cls.create(
|
||||
transactiontype=transaction_type,
|
||||
@@ -127,6 +130,7 @@ class EconomyData(Registry, name='economy'):
|
||||
return transaction
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='execute_transactions')
|
||||
async def execute_transactions(cls, *transactions):
|
||||
"""
|
||||
Execute multiple transactions in one data transaction.
|
||||
@@ -142,7 +146,8 @@ class EconomyData(Registry, name='economy'):
|
||||
if not transactions:
|
||||
return []
|
||||
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
cls._connector.conn = conn
|
||||
async with conn.transaction():
|
||||
# Create the transactions
|
||||
rows = await cls.table.insert_many(
|
||||
@@ -180,10 +185,12 @@ class EconomyData(Registry, name='economy'):
|
||||
return rows
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='refund_transactions')
|
||||
async def refund_transactions(cls, *transactionids, actorid=0):
|
||||
if not transactionids:
|
||||
return []
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
cls._connector.conn = conn
|
||||
async with conn.transaction():
|
||||
# First fetch the transaction rows to refund
|
||||
data = await cls.table.select_where(transactionid=transactionids)
|
||||
@@ -217,12 +224,14 @@ class EconomyData(Registry, name='economy'):
|
||||
itemid = Integer()
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='purchase_transaction')
|
||||
async def purchase_transaction(
|
||||
cls,
|
||||
guildid: int, actorid: int,
|
||||
userid: int, itemid: int, amount: int
|
||||
):
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
cls._connector.conn = conn
|
||||
async with conn.transaction():
|
||||
row = await EconomyData.Transaction.execute_transaction(
|
||||
TransactionType.SHOP_PURCHASE,
|
||||
@@ -263,12 +272,14 @@ class EconomyData(Registry, name='economy'):
|
||||
return result[0]['recent'] or 0
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='reward_completed_tasks')
|
||||
async def reward_completed(cls, userid, guildid, count, amount):
|
||||
"""
|
||||
Reward the specified member `amount` coins for completing `count` tasks.
|
||||
"""
|
||||
# TODO: Bonus logic, perhaps apply_bonus(amount), or put this method in the economy cog?
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
cls._connector.conn = conn
|
||||
async with conn.transaction():
|
||||
row = await EconomyData.Transaction.execute_transaction(
|
||||
TransactionType.TASKS,
|
||||
|
||||
@@ -193,7 +193,8 @@ class MemberAdminCog(LionCog):
|
||||
await lion.data.update(last_left=utc_now())
|
||||
|
||||
# Save member roles
|
||||
conn = await self.bot.db.get_connection()
|
||||
async with self.bot.db.connection() as conn:
|
||||
self.bot.db.conn = conn
|
||||
async with conn.transaction():
|
||||
await self.data.past_roles.delete_where(
|
||||
guildid=member.guild.id,
|
||||
|
||||
@@ -190,6 +190,8 @@ class ModerationCog(LionCog):
|
||||
update_args[instance._column] = instance.data
|
||||
ack_lines.append(instance.update_message)
|
||||
|
||||
await ctx.lguild.data.update(**update_args)
|
||||
|
||||
# Do the ack
|
||||
tick = self.bot.config.emojis.tick
|
||||
embed = discord.Embed(
|
||||
|
||||
@@ -483,8 +483,6 @@ class RoleMenu:
|
||||
)).format(role=role.name)
|
||||
)
|
||||
|
||||
conn = await self.bot.db.get_connection()
|
||||
async with conn.transaction():
|
||||
# Remove the role
|
||||
try:
|
||||
await member.remove_roles(role)
|
||||
@@ -591,8 +589,6 @@ class RoleMenu:
|
||||
)
|
||||
)
|
||||
|
||||
conn = await self.bot.db.get_connection()
|
||||
async with conn.transaction():
|
||||
try:
|
||||
await member.add_roles(role)
|
||||
except discord.Forbidden:
|
||||
|
||||
@@ -259,17 +259,6 @@ class RoomCog(LionCog):
|
||||
lguild,
|
||||
[member.id for member in members]
|
||||
)
|
||||
self._start(room)
|
||||
|
||||
# Send tips message
|
||||
# TODO: Actual tips.
|
||||
await channel.send(
|
||||
"{mention} welcome to your private room! You may use the menu below to configure it.".format(mention=owner.mention)
|
||||
)
|
||||
|
||||
# Send config UI
|
||||
ui = RoomUI(self.bot, room, callerid=owner.id, timeout=None)
|
||||
await ui.send(channel)
|
||||
except Exception:
|
||||
try:
|
||||
await channel.delete(reason="Failed to created private room")
|
||||
@@ -454,7 +443,42 @@ class RoomCog(LionCog):
|
||||
return
|
||||
|
||||
# Positive response. Start a transaction.
|
||||
conn = await self.bot.db.get_connection()
|
||||
room = await self._do_create_room(ctx, required, days, rent, name, provided)
|
||||
|
||||
if room:
|
||||
# Ack with confirmation message pointing to the room
|
||||
msg = t(_p(
|
||||
'cmd:room_rent|success',
|
||||
"Successfully created your private room {channel}!"
|
||||
)).format(channel=room.channel.mention)
|
||||
await ctx.reply(
|
||||
embed=discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
title=t(_p('cmd:room_rent|success|title', "Private Room Created!")),
|
||||
description=msg
|
||||
)
|
||||
)
|
||||
self._start(room)
|
||||
|
||||
# Send tips message
|
||||
# TODO: Actual tips.
|
||||
await room.channel.send(
|
||||
"{mention} welcome to your private room! You may use the menu below to configure it.".format(
|
||||
mention=ctx.author.mention
|
||||
)
|
||||
)
|
||||
|
||||
# Send config UI
|
||||
ui = RoomUI(self.bot, room, callerid=ctx.author.id, timeout=None)
|
||||
await ui.send(room.channel)
|
||||
|
||||
@log_wrap(action='create_room')
|
||||
async def _do_create_room(self, ctx, required, days, rent, name, provided) -> Room:
|
||||
t = self.bot.translator.t
|
||||
# TODO: Rollback the channel create if this fails
|
||||
async with self.bot.db.connection() as conn:
|
||||
self.bot.db.conn = conn
|
||||
# Note that the room creation will go into the UI as well.
|
||||
async with conn.transaction():
|
||||
# Check member balance is sufficient
|
||||
await ctx.alion.data.refresh()
|
||||
@@ -486,7 +510,7 @@ class RoomCog(LionCog):
|
||||
|
||||
# Create room with given starting balance and other parameters
|
||||
try:
|
||||
room = await self.create_private_room(
|
||||
return await self.create_private_room(
|
||||
ctx.guild,
|
||||
ctx.author,
|
||||
required - rent,
|
||||
@@ -519,19 +543,6 @@ class RoomCog(LionCog):
|
||||
await ctx.alion.data.update(coins=CoreData.Member.coins + required)
|
||||
return
|
||||
|
||||
# Ack with confirmation message pointing to the room
|
||||
msg = t(_p(
|
||||
'cmd:room_rent|success',
|
||||
"Successfully created your private room {channel}!"
|
||||
)).format(channel=room.channel.mention)
|
||||
await ctx.reply(
|
||||
embed=discord.Embed(
|
||||
colour=discord.Colour.brand_green(),
|
||||
title=t(_p('cmd:room_rent|success|title', "Private Room Created!")),
|
||||
description=msg
|
||||
)
|
||||
)
|
||||
|
||||
@room_group.command(
|
||||
name=_p('cmd:room_status', "status"),
|
||||
description=_p(
|
||||
@@ -864,8 +875,7 @@ class RoomCog(LionCog):
|
||||
return
|
||||
|
||||
# Start Transaction
|
||||
conn = await self.bot.db.get_connection()
|
||||
async with conn.transaction():
|
||||
# TODO: Economy transaction
|
||||
await ctx.alion.data.refresh()
|
||||
member_balance = ctx.alion.data.coins
|
||||
if member_balance < coins:
|
||||
@@ -883,7 +893,6 @@ class RoomCog(LionCog):
|
||||
return
|
||||
|
||||
# Deduct balance
|
||||
# TODO: Economy transaction
|
||||
await ctx.alion.data.update(coins=CoreData.Member.coins - coins)
|
||||
await room.data.update(coin_balance=RoomData.Room.coin_balance + coins)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from discord.ui.select import select, UserSelect
|
||||
|
||||
from meta import LionBot, conf
|
||||
from meta.errors import UserInputError
|
||||
from meta.logger import log_wrap
|
||||
from babel.translator import ctx_locale
|
||||
from utils.lib import utc_now, MessageArgs, error_embed
|
||||
from utils.ui import MessageUI, input
|
||||
@@ -115,8 +116,18 @@ class RoomUI(MessageUI):
|
||||
return
|
||||
await submit.response.defer(thinking=True, ephemeral=True)
|
||||
|
||||
await self._do_deposit(t, press, amount, submit)
|
||||
|
||||
# Post deposit message
|
||||
await self.room.notify_deposit(press.user, amount)
|
||||
|
||||
await self.refresh(thinking=submit)
|
||||
|
||||
@log_wrap(isolate=True)
|
||||
async def _do_deposit(self, t, press, amount, submit):
|
||||
# Start transaction for deposit
|
||||
conn = await self.bot.db.get_connection()
|
||||
async with self.bot.db.connection() as conn:
|
||||
self.bot.db.conn = conn
|
||||
async with conn.transaction():
|
||||
# Get the lion balance directly
|
||||
lion = await self.bot.core.data.Member.fetch(
|
||||
@@ -143,11 +154,6 @@ class RoomUI(MessageUI):
|
||||
await lion.update(coins=CoreData.Member.coins - amount)
|
||||
await self.room.data.update(coin_balance=RoomData.Room.coin_balance + amount)
|
||||
|
||||
# Post deposit message
|
||||
await self.room.notify_deposit(press.user, amount)
|
||||
|
||||
await self.refresh(thinking=submit)
|
||||
|
||||
async def desposit_button_refresh(self):
|
||||
self.desposit_button.label = self.bot.translator.t(_p(
|
||||
'ui:room_status|button:deposit|label',
|
||||
|
||||
@@ -217,7 +217,8 @@ class ScheduleCog(LionCog):
|
||||
for bookingid in bookingids:
|
||||
await self._cancel_booking_active(*bookingid)
|
||||
|
||||
conn = await self.bot.db.get_connection()
|
||||
async with self.bot.db.connection() as conn:
|
||||
self.bot.db.conn = conn
|
||||
async with conn.transaction():
|
||||
# Now delete from data
|
||||
records = await self.data.ScheduleSessionMember.table.delete_where(
|
||||
@@ -473,7 +474,6 @@ class ScheduleCog(LionCog):
|
||||
"One or more requested timeslots are already booked!"
|
||||
))
|
||||
raise UserInputError(error)
|
||||
conn = await self.bot.db.get_connection()
|
||||
|
||||
# Booking request is now validated. Perform bookings.
|
||||
# Fetch or create session data
|
||||
@@ -482,8 +482,8 @@ class ScheduleCog(LionCog):
|
||||
*((guildid, slotid) for slotid in slotids)
|
||||
)
|
||||
|
||||
async with conn.transaction():
|
||||
# Create transactions
|
||||
# TODO: wrap in a transaction so the economy transaction gets unwound if it fails
|
||||
economy = self.bot.get_cog('Economy')
|
||||
trans_data = (
|
||||
TransactionType.SCHEDULE_BOOK,
|
||||
|
||||
@@ -356,8 +356,7 @@ class TimeSlot:
|
||||
Does not modify session room channels (responsibility of the next open).
|
||||
"""
|
||||
try:
|
||||
conn = await self.bot.db.get_connection()
|
||||
async with conn.transaction():
|
||||
# TODO: Transaction?
|
||||
# Calculate rewards
|
||||
rewards = []
|
||||
attendance = []
|
||||
@@ -532,8 +531,7 @@ class TimeSlot:
|
||||
This involves refunding the booking transactions, deleting the booking rows,
|
||||
and updating any messages that may have been posted.
|
||||
"""
|
||||
conn = await self.bot.db.get_connection()
|
||||
async with conn.transaction():
|
||||
# TODO: Transaction
|
||||
# Collect booking rows
|
||||
bookings = [member.data for session in sessions for member in session.members.values()]
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from discord.ui.button import button, Button
|
||||
|
||||
from meta import LionCog, LionContext, LionBot
|
||||
from meta.errors import SafeCancellation
|
||||
from meta.logger import log_wrap
|
||||
from utils import ui
|
||||
from utils.lib import error_embed
|
||||
from constants import MAX_COINS
|
||||
@@ -145,6 +146,7 @@ class ColourShop(Shop):
|
||||
if (owned is None or item.itemid != owned.itemid) and (item.price <= balance)
|
||||
]
|
||||
|
||||
@log_wrap(action='purchase')
|
||||
async def purchase(self, itemid) -> ColourRoleItem:
|
||||
"""
|
||||
Atomically handle a purchase of a ColourRoleItem.
|
||||
@@ -157,7 +159,8 @@ class ColourShop(Shop):
|
||||
If the purchase fails for a known reason, raises SafeCancellation, with the error information.
|
||||
"""
|
||||
t = self.bot.translator.t
|
||||
conn = await self.bot.db.get_connection()
|
||||
async with self.bot.db.connection() as conn:
|
||||
self.bot.db.conn = conn
|
||||
async with conn.transaction():
|
||||
# Retrieve the item to purchase from data
|
||||
item = await self.data.ShopItemInfo.table.select_one_where(itemid=itemid)
|
||||
|
||||
@@ -4,6 +4,7 @@ from enum import Enum
|
||||
from itertools import chain
|
||||
from psycopg import sql
|
||||
|
||||
from meta.logger import log_wrap
|
||||
from data import RowModel, Registry, Table, RegisterEnum
|
||||
from data.columns import Integer, String, Timestamp, Bool, Column
|
||||
|
||||
@@ -80,6 +81,7 @@ class StatsData(Registry):
|
||||
end_time = Timestamp()
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='tracked_time_between')
|
||||
async def tracked_time_between(cls, *points: tuple[int, int, dt.datetime, dt.datetime]):
|
||||
query = sql.SQL(
|
||||
"""
|
||||
@@ -103,7 +105,7 @@ class StatsData(Registry):
|
||||
for _ in points
|
||||
)
|
||||
)
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
@@ -112,8 +114,9 @@ class StatsData(Registry):
|
||||
return cursor.fetchall()
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='study_time_between')
|
||||
async def study_time_between(cls, guildid: int, userid: int, _start, _end) -> int:
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"SELECT study_time_between(%s, %s, %s, %s)",
|
||||
@@ -122,6 +125,7 @@ class StatsData(Registry):
|
||||
return (await cursor.fetchone()[0]) or 0
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='study_times_between')
|
||||
async def study_times_between(cls, guildid: int, userid: int, *points) -> list[int]:
|
||||
if len(points) < 2:
|
||||
raise ValueError('Not enough block points given!')
|
||||
@@ -141,7 +145,7 @@ class StatsData(Registry):
|
||||
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
|
||||
)
|
||||
)
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
@@ -150,8 +154,9 @@ class StatsData(Registry):
|
||||
return [r['stime'] or 0 for r in await cursor.fetchall()]
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='study_time_since')
|
||||
async def study_time_since(cls, guildid: int, userid: int, _start) -> int:
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"SELECT study_time_since(%s, %s, %s)",
|
||||
@@ -160,6 +165,7 @@ class StatsData(Registry):
|
||||
return (await cursor.fetchone()[0]) or 0
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='study_times_between')
|
||||
async def study_times_since(cls, guildid: int, userid: int, *starts) -> int:
|
||||
if len(starts) < 1:
|
||||
raise ValueError('No starting points given!')
|
||||
@@ -178,7 +184,7 @@ class StatsData(Registry):
|
||||
sql.SQL("({})").format(sql.Placeholder()) for _ in starts
|
||||
)
|
||||
)
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
@@ -187,6 +193,7 @@ class StatsData(Registry):
|
||||
return [r['stime'] or 0 for r in await cursor.fetchall()]
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='leaderboard_since')
|
||||
async def leaderboard_since(cls, guildid: int, since):
|
||||
"""
|
||||
Return the voice totals since the given time for each member in the guild.
|
||||
@@ -226,7 +233,8 @@ class StatsData(Registry):
|
||||
)
|
||||
second_query_args = (since, guildid, since, since)
|
||||
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
cls._connector.conn = conn
|
||||
async with conn.transaction():
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(second_query, second_query_args)
|
||||
@@ -243,6 +251,7 @@ class StatsData(Registry):
|
||||
return leaderboard
|
||||
|
||||
@classmethod
|
||||
@log_wrap('leaderboard_all')
|
||||
async def leaderboard_all(cls, guildid: int):
|
||||
"""
|
||||
Return the all-time voice totals for the given guild.
|
||||
@@ -257,7 +266,7 @@ class StatsData(Registry):
|
||||
"""
|
||||
)
|
||||
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(query, (guildid, ))
|
||||
leaderboard = [
|
||||
@@ -296,6 +305,7 @@ class StatsData(Registry):
|
||||
transactionid = Integer()
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='xp_since')
|
||||
async def xp_since(cls, guildid: int, userid: int, *starts):
|
||||
query = sql.SQL(
|
||||
"""
|
||||
@@ -320,7 +330,7 @@ class StatsData(Registry):
|
||||
sql.Placeholder() for _ in starts
|
||||
)
|
||||
)
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
@@ -329,6 +339,7 @@ class StatsData(Registry):
|
||||
return [r['exp'] or 0 for r in await cursor.fetchall()]
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='xp_between')
|
||||
async def xp_between(cls, guildid: int, userid: int, *points):
|
||||
blocks = zip(points, points[1:])
|
||||
query = sql.SQL(
|
||||
@@ -355,7 +366,7 @@ class StatsData(Registry):
|
||||
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
|
||||
)
|
||||
)
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
@@ -364,6 +375,7 @@ class StatsData(Registry):
|
||||
return [r['period_xp'] or 0 for r in await cursor.fetchall()]
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='leaderboard_since')
|
||||
async def leaderboard_since(cls, guildid: int, since):
|
||||
"""
|
||||
Return the XP totals for the given guild since the given time.
|
||||
@@ -378,7 +390,7 @@ class StatsData(Registry):
|
||||
"""
|
||||
)
|
||||
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(query, (guildid, since))
|
||||
leaderboard = [
|
||||
@@ -388,6 +400,7 @@ class StatsData(Registry):
|
||||
return leaderboard
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='leaderboard_all')
|
||||
async def leaderboard_all(cls, guildid: int):
|
||||
"""
|
||||
Return the all-time XP totals for the given guild.
|
||||
@@ -402,7 +415,7 @@ class StatsData(Registry):
|
||||
"""
|
||||
)
|
||||
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(query, (guildid, ))
|
||||
leaderboard = [
|
||||
@@ -436,6 +449,7 @@ class StatsData(Registry):
|
||||
exp_type: Column[ExpType] = Column()
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='user_xp_since')
|
||||
async def xp_since(cls, userid: int, *starts):
|
||||
query = sql.SQL(
|
||||
"""
|
||||
@@ -459,7 +473,7 @@ class StatsData(Registry):
|
||||
sql.Placeholder() for _ in starts
|
||||
)
|
||||
)
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
@@ -468,6 +482,7 @@ class StatsData(Registry):
|
||||
return [r['exp'] or 0 for r in await cursor.fetchall()]
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='user_xp_since')
|
||||
async def xp_between(cls, userid: int, *points):
|
||||
blocks = zip(points, points[1:])
|
||||
query = sql.SQL(
|
||||
@@ -493,7 +508,7 @@ class StatsData(Registry):
|
||||
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
|
||||
)
|
||||
)
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
@@ -531,8 +546,10 @@ class StatsData(Registry):
|
||||
return [tag.tag for tag in tags]
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='set_profile_tags')
|
||||
async def set_tags(self, guildid: Optional[int], userid: int, tags: Iterable[str]):
|
||||
conn = await self._connector.get_connection()
|
||||
async with self._connector.connection() as conn:
|
||||
self._connector.conn = conn
|
||||
async with conn.transaction():
|
||||
await self.table.delete_where(guildid=guildid, userid=userid)
|
||||
if tags:
|
||||
|
||||
@@ -473,14 +473,14 @@ class WeeklyMonthlyUI(StatsUI):
|
||||
# Update the tasklist
|
||||
if len(new_tasks) != len(tasks) or not all(t == new_t for (t, new_t) in zip(tasks, new_tasks)):
|
||||
modified = True
|
||||
conn = await self.bot.db.get_connection()
|
||||
async with self._connector.connection() as conn:
|
||||
async with conn.transaction():
|
||||
await tasks_model.table.delete_where(**key)
|
||||
await tasks_model.table.delete_where(**key).with_connection(conn)
|
||||
if new_tasks:
|
||||
await tasks_model.table.insert_many(
|
||||
(*key.keys(), 'completed', 'content'),
|
||||
*((*key.values(), *new_task) for new_task in new_tasks)
|
||||
)
|
||||
).with_connection(conn)
|
||||
|
||||
if modified:
|
||||
# If either goal type was modified, clear the rendered cache and refresh
|
||||
|
||||
@@ -8,6 +8,7 @@ from discord import app_commands as appcmds
|
||||
from discord.app_commands.transformers import AppCommandOptionType as cmdopt
|
||||
|
||||
from meta import LionBot, LionCog, LionContext
|
||||
from meta.logger import log_wrap
|
||||
from meta.errors import UserInputError
|
||||
from utils.lib import utc_now, error_embed
|
||||
from utils.ui import ChoicedEnum, Transformed, AButton
|
||||
@@ -141,8 +142,10 @@ class TasklistCog(LionCog):
|
||||
self.crossload_group(self.configure_group, configcog.configure_group)
|
||||
|
||||
@LionCog.listener('on_tasks_completed')
|
||||
@log_wrap(action="reward tasks completed")
|
||||
async def reward_tasks_completed(self, member: discord.Member, *taskids: int):
|
||||
conn = await self.bot.db.get_connection()
|
||||
async with self.bot.db.connection() as conn:
|
||||
self.bot.db.conn = conn
|
||||
async with conn.transaction():
|
||||
tasklist = await Tasklist.fetch(self.bot, self.data, member.id)
|
||||
tasks = await tasklist.fetch_tasks(*taskids)
|
||||
@@ -477,11 +480,8 @@ class TasklistCog(LionCog):
|
||||
# Contents successfully parsed, update the tasklist.
|
||||
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
|
||||
|
||||
# Lazily using the editor because it has a good parser
|
||||
taskinfo = tasklist.parse_tasklist(lines)
|
||||
|
||||
conn = await self.bot.db.get_connection()
|
||||
async with conn.transaction():
|
||||
now = utc_now()
|
||||
|
||||
# Delete tasklist if required
|
||||
|
||||
@@ -11,6 +11,7 @@ from discord.ui.button import button, Button, ButtonStyle
|
||||
from discord.ui.text_input import TextInput, TextStyle
|
||||
|
||||
from meta import conf
|
||||
from meta.logger import log_wrap
|
||||
from meta.errors import UserInputError
|
||||
from utils.lib import MessageArgs, utc_now
|
||||
from utils.ui import LeoUI, LeoModal, FastModal, error_handler_for, ModalRetryUI
|
||||
@@ -143,6 +144,7 @@ class BulkEditor(LeoModal):
|
||||
except UserInputError as error:
|
||||
await ModalRetryUI(self, error.msg).respond_to(interaction)
|
||||
|
||||
@log_wrap(action="parse editor")
|
||||
async def parse_editor(self):
|
||||
# First parse each line
|
||||
new_lines = self.tasklist_editor.value.splitlines()
|
||||
@@ -155,7 +157,8 @@ class BulkEditor(LeoModal):
|
||||
)
|
||||
|
||||
# TODO: Incremental/diff editing
|
||||
conn = await self.bot.db.get_connection()
|
||||
async with self.bot.db.connection() as conn:
|
||||
self.bot.db.conn = conn
|
||||
async with conn.transaction():
|
||||
now = utc_now()
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Type
|
||||
import json
|
||||
|
||||
from data import RowModel, Table, ORDER
|
||||
from meta.logger import log_wrap, set_logging_context
|
||||
|
||||
|
||||
class ModelData:
|
||||
@@ -60,6 +61,7 @@ class ModelData:
|
||||
It only updates.
|
||||
"""
|
||||
# TODO: Better way of getting the key?
|
||||
# TODO: Transaction
|
||||
if not isinstance(parent_id, tuple):
|
||||
parent_id = (parent_id, )
|
||||
model = cls._model
|
||||
@@ -83,6 +85,8 @@ class ListData:
|
||||
This assumes the list is the only data stored in the table,
|
||||
and removes list entries by deleting rows.
|
||||
"""
|
||||
setting_id: str
|
||||
|
||||
# Table storing the setting data
|
||||
_table_interface: Table
|
||||
|
||||
@@ -100,10 +104,12 @@ class ListData:
|
||||
_cache = None # Map[id -> value]
|
||||
|
||||
@classmethod
|
||||
@log_wrap(isolate=True)
|
||||
async def _reader(cls, parent_id, use_cache=True, **kwargs):
|
||||
"""
|
||||
Read in all entries associated to the given id.
|
||||
"""
|
||||
set_logging_context(action="Read cls.setting_id")
|
||||
if cls._cache is not None and parent_id in cls._cache and use_cache:
|
||||
return cls._cache[parent_id]
|
||||
|
||||
@@ -121,12 +127,15 @@ class ListData:
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
@log_wrap(isolate=True)
|
||||
async def _writer(cls, id, data, add_only=False, remove_only=False, **kwargs):
|
||||
"""
|
||||
Write the provided list to storage.
|
||||
"""
|
||||
set_logging_context(action="Write cls.setting_id")
|
||||
table = cls._table_interface
|
||||
conn = await table.connector.get_connection()
|
||||
async with table.connector.connection() as conn:
|
||||
table.connector.conn = conn
|
||||
async with conn.transaction():
|
||||
# Handle None input as an empty list
|
||||
if data is None:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from itertools import chain
|
||||
from psycopg import sql
|
||||
|
||||
|
||||
from meta.logger import log_wrap
|
||||
from data import RowModel, Registry, Table
|
||||
from data.columns import Integer, String, Timestamp, Bool
|
||||
|
||||
@@ -72,6 +72,7 @@ class TextTrackerData(Registry):
|
||||
member_expid = Integer()
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='end_text_sessions')
|
||||
async def end_sessions(cls, connector, *session_data):
|
||||
query = sql.SQL("""
|
||||
WITH
|
||||
@@ -92,7 +93,7 @@ class TextTrackerData(Registry):
|
||||
) SELECT
|
||||
data._guildid, 0,
|
||||
NULL, data._userid,
|
||||
SUM(_coins), 0, 'TEXT_SESSION'
|
||||
LEAST(SUM(_coins :: BIGINT), 2147483647), 0, 'TEXT_SESSION'
|
||||
FROM data
|
||||
WHERE data._coins > 0
|
||||
GROUP BY (data._guildid, data._userid)
|
||||
@@ -100,7 +101,7 @@ class TextTrackerData(Registry):
|
||||
)
|
||||
, member AS (
|
||||
UPDATE members
|
||||
SET coins = coins + data._coins
|
||||
SET coins = LEAST(coins :: BIGINT + data._coins :: BIGINT, 2147483647)
|
||||
FROM data
|
||||
WHERE members.userid = data._userid AND members.guildid = data._guildid
|
||||
)
|
||||
@@ -166,7 +167,7 @@ class TextTrackerData(Registry):
|
||||
# Or ask for a connection from the connection pool
|
||||
# Transaction may take some time due to index updates
|
||||
# Alternatively maybe use the "do not expect response mode"
|
||||
conn = await connector.get_connection()
|
||||
async with connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
@@ -175,6 +176,7 @@ class TextTrackerData(Registry):
|
||||
return
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='user_messages_between')
|
||||
async def user_messages_between(cls, userid: int, *points):
|
||||
"""
|
||||
Compute messages written between the given points.
|
||||
@@ -203,7 +205,7 @@ class TextTrackerData(Registry):
|
||||
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
|
||||
)
|
||||
)
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
@@ -212,6 +214,7 @@ class TextTrackerData(Registry):
|
||||
return [r['period_m'] or 0 for r in await cursor.fetchall()]
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='member_messages_between')
|
||||
async def member_messages_between(cls, guildid: int, userid: int, *points):
|
||||
"""
|
||||
Compute messages written between the given points.
|
||||
@@ -241,7 +244,7 @@ class TextTrackerData(Registry):
|
||||
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
|
||||
)
|
||||
)
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
@@ -250,6 +253,7 @@ class TextTrackerData(Registry):
|
||||
return [r['period_m'] or 0 for r in await cursor.fetchall()]
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='member_messages_since')
|
||||
async def member_messages_since(cls, guildid: int, userid: int, *points):
|
||||
"""
|
||||
Compute messages written between the given points.
|
||||
@@ -277,7 +281,7 @@ class TextTrackerData(Registry):
|
||||
sql.SQL("({})").format(sql.Placeholder()) for _ in points
|
||||
)
|
||||
)
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
|
||||
@@ -2,6 +2,7 @@ import datetime as dt
|
||||
from itertools import chain
|
||||
from psycopg import sql
|
||||
|
||||
from meta.logger import log_wrap
|
||||
from data import RowModel, Registry, Table
|
||||
from data.columns import Integer, String, Timestamp, Bool
|
||||
|
||||
@@ -113,16 +114,18 @@ class VoiceTrackerData(Registry):
|
||||
hourly_coins = Integer()
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='close_voice_session')
|
||||
async def close_study_session_at(cls, guildid: int, userid: int, _at: dt.datetime) -> int:
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"SELECT close_study_session_at(%s, %s, %s)",
|
||||
(guildid, userid, _at)
|
||||
)
|
||||
member_data = await cursor.fetchone()
|
||||
return await cursor.fetchone()
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='close_voice_sessions')
|
||||
async def close_voice_sessions_at(cls, *arg_tuples):
|
||||
query = sql.SQL("""
|
||||
SELECT
|
||||
@@ -139,7 +142,7 @@ class VoiceTrackerData(Registry):
|
||||
for _ in arg_tuples
|
||||
)
|
||||
)
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
@@ -147,11 +150,12 @@ class VoiceTrackerData(Registry):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='update_voice_session')
|
||||
async def update_voice_session_at(
|
||||
cls, guildid: int, userid: int, _at: dt.datetime,
|
||||
stream: bool, video: bool, rate: float
|
||||
) -> int:
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"SELECT * FROM update_voice_session(%s, %s, %s, %s, %s, %s)",
|
||||
@@ -161,6 +165,7 @@ class VoiceTrackerData(Registry):
|
||||
return cls._make_rows(*rows)
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='update_voice_sessions')
|
||||
async def update_voice_sessions_at(cls, *arg_tuples):
|
||||
query = sql.SQL("""
|
||||
UPDATE
|
||||
@@ -209,7 +214,7 @@ class VoiceTrackerData(Registry):
|
||||
for _ in arg_tuples
|
||||
)
|
||||
)
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
@@ -257,8 +262,9 @@ class VoiceTrackerData(Registry):
|
||||
transactionid = Integer()
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='study_time_since')
|
||||
async def study_time_since(cls, guildid: int, userid: int, _start) -> int:
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"SELECT study_time_since(%s, %s, %s) AS result",
|
||||
@@ -268,6 +274,7 @@ class VoiceTrackerData(Registry):
|
||||
return (result['result'] or 0) if result else 0
|
||||
|
||||
@classmethod
|
||||
@log_wrap(action='multiple_voice_tracked_since')
|
||||
async def multiple_voice_tracked_since(cls, *arg_tuples):
|
||||
query = sql.SQL("""
|
||||
SELECT
|
||||
@@ -286,7 +293,7 @@ class VoiceTrackerData(Registry):
|
||||
for _ in arg_tuples
|
||||
)
|
||||
)
|
||||
conn = await cls._connector.get_connection()
|
||||
async with cls._connector.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
query,
|
||||
|
||||
@@ -161,7 +161,6 @@ class VoiceSession:
|
||||
self.state.channelid, guildid=self.guildid, deleted=False
|
||||
)
|
||||
|
||||
conn = await self.bot.db.get_connection()
|
||||
# Insert an ongoing_session with the correct state, set data
|
||||
state = self.state
|
||||
self.data = await self.registry.VoiceSessionsOngoing.create(
|
||||
|
||||
Reference in New Issue
Block a user