fix (data): Parallel connection pool.

This commit is contained in:
2023-08-23 17:31:38 +03:00
parent 5bca9bca33
commit df9b835cd5
27 changed files with 1175 additions and 1021 deletions

View File

@@ -52,7 +52,7 @@ class EventHandler(Generic[T]):
f"Queue on event handler {self.route_name} is full! Discarding event {data}"
)
@log_wrap(action='consumer', isolate=False)
@log_wrap(action='consumer')
async def consumer(self):
while True:
try:
@@ -76,7 +76,7 @@ class EventHandler(Generic[T]):
)
pass
@log_wrap(action='batch', isolate=False)
@log_wrap(action='batch')
async def process_batch(self):
logger.debug("Processing Batch")
# TODO: copy syntax might be more efficient here

View File

@@ -123,7 +123,7 @@ class AnalyticsServer:
log_action_stack.set(['Analytics'])
log_app.set(conf.analytics['appname'])
async with await self.db.connect():
async with self.db.open():
await self.talk.connect()
await self.attach_event_handlers()
self._snap_task = asyncio.create_task(self.snapshot_loop())

View File

@@ -38,7 +38,7 @@ async def main():
intents.message_content = True
intents.presences = False
async with await db.connect():
async with db.open():
version = await db.version()
if version.version != DATA_VERSION:
error = f"Data model version is {version}, required version is {DATA_VERSION}! Please migrate."

View File

@@ -57,16 +57,14 @@ class CoreCog(LionCog):
async def cog_load(self):
# Fetch (and possibly create) core data rows.
conn = await self.bot.db.get_connection()
async with conn.transaction():
self.app_config = await self.data.AppConfig.fetch_or_create(appname)
self.bot_config = await self.data.BotConfig.fetch_or_create(appname)
self.shard_data = await self.data.Shard.fetch_or_create(
shardname,
appname=appname,
shard_id=self.bot.shard_id,
shard_count=self.bot.shard_count
)
self.app_config = await self.data.AppConfig.fetch_or_create(appname)
self.bot_config = await self.data.BotConfig.fetch_or_create(appname)
self.shard_data = await self.data.Shard.fetch_or_create(
shardname,
appname=appname,
shard_id=self.bot.shard_id,
shard_count=self.bot.shard_count
)
self.bot.add_listener(self.shard_update_guilds, name='on_guild_join')
self.bot.add_listener(self.shard_update_guilds, name='on_guild_remove')

View File

@@ -5,6 +5,7 @@ from cachetools import TTLCache
import discord
from meta import conf
from meta.logger import log_wrap
from data import Table, Registry, Column, RowModel, RegisterEnum
from data.models import WeakCache
from data.columns import Integer, String, Bool, Timestamp
@@ -287,6 +288,7 @@ class CoreData(Registry, name="core"):
_timestamp = Timestamp()
@classmethod
@log_wrap(action="Add Pending Coins")
async def add_pending(cls, pending: list[tuple[int, int, int]]) -> list['CoreData.Member']:
"""
Safely add pending coins to a list of members.
@@ -316,39 +318,40 @@ class CoreData(Registry, name="core"):
)
)
# TODO: Replace with copy syntax/query?
conn = await cls.table.connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain(*pending))
)
rows = await cursor.fetchall()
return cls._make_rows(*rows)
async with cls.table.connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain(*pending))
)
rows = await cursor.fetchall()
return cls._make_rows(*rows)
@classmethod
@log_wrap(action='get_member_rank')
async def get_member_rank(cls, guildid, userid, untracked):
"""
Get the time and coin ranking for the given member, ignoring the provided untracked members.
"""
conn = await cls.table.connector.get_connection()
async with conn.cursor() as curs:
await curs.execute(
"""
SELECT
time_rank, coin_rank
FROM (
SELECT
userid,
row_number() OVER (ORDER BY total_tracked_time DESC, userid ASC) AS time_rank,
row_number() OVER (ORDER BY total_coins DESC, userid ASC) AS coin_rank
FROM members_totals
WHERE
guildid=%s AND userid NOT IN %s
) AS guild_ranks WHERE userid=%s
""",
(guildid, tuple(untracked), userid)
)
return (await curs.fetchone()) or (None, None)
async with cls.table.connector.connection() as conn:
async with conn.cursor() as curs:
await curs.execute(
"""
SELECT
time_rank, coin_rank
FROM (
SELECT
userid,
row_number() OVER (ORDER BY total_tracked_time DESC, userid ASC) AS time_rank,
row_number() OVER (ORDER BY total_coins DESC, userid ASC) AS coin_rank
FROM members_totals
WHERE
guildid=%s AND userid NOT IN %s
) AS guild_ranks WHERE userid=%s
""",
(guildid, tuple(untracked), userid)
)
return (await curs.fetchone()) or (None, None)
class LionHook(RowModel):
"""

View File

@@ -1,6 +1,7 @@
# from enum import Enum
from typing import Optional
from psycopg.types.enum import register_enum, EnumInfo
from psycopg import AsyncConnection
from .registry import Attachable, Registry
@@ -23,10 +24,17 @@ class RegisterEnum(Attachable):
connector = registry._conn
if connector is None:
raise ValueError("Cannot initialise without connector!")
connection = await connector.get_connection()
if connection is None:
raise ValueError("Cannot Init without connection.")
info = await EnumInfo.fetch(connection, self.name)
connector.connect_hook(self.connection_hook)
# await connector.refresh_pool()
# The below may be somewhat dangerous
# But adaption should never write to the database
await connector.map_over_pool(self.connection_hook)
# if conn := connector.conn:
# # Ensure the adaption is run in the current context as well
# await self.connection_hook(conn)
async def connection_hook(self, conn: AsyncConnection):
info = await EnumInfo.fetch(conn, self.name)
if info is None:
raise ValueError(f"Enum {self.name} not found in database.")
register_enum(info, connection, self.enum, mapping=list(self.mapping.items()))
register_enum(info, conn, self.enum, mapping=list(self.mapping.items()))

View File

@@ -1,7 +1,10 @@
from typing import Protocol, runtime_checkable, Callable, Awaitable
from typing import Protocol, runtime_checkable, Callable, Awaitable, Optional
import logging
from contextvars import ContextVar
from contextlib import asynccontextmanager
import psycopg as psq
from psycopg_pool import AsyncConnectionPool
from psycopg.pq import TransactionStatus
from .cursor import AsyncLoggingCursor
@@ -10,42 +13,110 @@ logger = logging.getLogger(__name__)
row_factory = psq.rows.dict_row
ctx_connection: Optional[ContextVar[psq.AsyncConnection]] = ContextVar('connection', default=None)
class Connector:
cursor_factory = AsyncLoggingCursor
def __init__(self, conn_args):
self._conn_args = conn_args
self.conn: psq.AsyncConnection = None
self._conn_kwargs = dict(autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory)
self.pool = self.make_pool()
self.conn_hooks = []
async def get_connection(self) -> psq.AsyncConnection:
@property
def conn(self) -> Optional[psq.AsyncConnection]:
"""
Get the current active connection.
This should never be cached outside of a transaction.
Convenience property for the current context connection.
"""
# TODO: Reconnection logic?
if not self.conn:
raise ValueError("Attempting to get connection before initialisation!")
if self.conn.info.transaction_status is TransactionStatus.INERROR:
await self.connect()
logger.error(
"Database connection transaction failed!! This should not happen. Reconnecting."
)
return self.conn
return ctx_connection.get()
async def connect(self) -> psq.AsyncConnection:
logger.info("Establishing connection to database.", extra={'action': "Data Connect"})
self.conn = await psq.AsyncConnection.connect(
self._conn_args, autocommit=True, row_factory=row_factory, cursor_factory=self.cursor_factory
@conn.setter
def conn(self, conn: psq.AsyncConnection):
"""
Set the contextual connection in the current context.
Always do this in an isolated context!
"""
ctx_connection.set(conn)
def make_pool(self) -> AsyncConnectionPool:
logger.info("Initialising connection pool.", extra={'action': "Pool Init"})
return AsyncConnectionPool(
self._conn_args,
open=False,
min_size=4,
max_size=8,
configure=self._setup_connection,
kwargs=self._conn_kwargs
)
for hook in self.conn_hooks:
await hook(self.conn)
return self.conn
async def reconnect(self) -> psq.AsyncConnection:
return await self.connect()
async def refresh_pool(self):
"""
Refresh the pool.
The point of this is to invalidate any existing connections so that the connection set up is run again.
Better ways should be sought (a way to
"""
logger.info("Pool refresh requested, closing and reopening.")
old_pool = self.pool
self.pool = self.make_pool()
await self.pool.open()
logger.info(f"Old pool statistics: {self.pool.get_stats()}")
await old_pool.close()
logger.info("Pool refresh complete.")
async def map_over_pool(self, callable):
"""
Dangerous method to call a method on each connection in the pool.
Utilises private methods of the AsyncConnectionPool.
"""
async with self.pool._lock:
conns = list(self.pool._pool)
while conns:
conn = conns.pop()
try:
await callable(conn)
except Exception:
logger.exception(f"Mapped connection task failed. {callable.__name__}")
@asynccontextmanager
async def open(self):
try:
logger.info("Opening database pool.")
await self.pool.open()
yield
finally:
# May be a different pool!
logger.info(f"Closing database pool. Pool statistics: {self.pool.get_stats()}")
await self.pool.close()
@asynccontextmanager
async def connection(self) -> psq.AsyncConnection:
"""
Asynchronous context manager to get and manage a connection.
If the context connection is set, uses this and does not manage the lifetime.
Otherwise, requests a new connection from the pool and returns it when done.
"""
logger.debug("Database connection requested.", extra={'action': "Data Connect"})
if (conn := self.conn):
yield conn
else:
async with self.pool.connection() as conn:
yield conn
async def _setup_connection(self, conn: psq.AsyncConnection):
logger.debug("Initialising new connection.", extra={'action': "Conn Init"})
for hook in self.conn_hooks:
try:
await hook(conn)
except Exception:
logger.exception("Exception encountered setting up new connection")
return conn
def connect_hook(self, coro: Callable[[psq.AsyncConnection], Awaitable[None]]):
"""

View File

@@ -35,12 +35,13 @@ class Database(Connector):
"""
Return the current schema version as a Version namedtuple.
"""
async with self.conn.cursor() as cursor:
# Get last entry in version table, compare against desired version
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
row = await cursor.fetchone()
if row:
return Version(row['version'], row['time'], row['author'])
else:
# No versions in the database
return Version(-1, None, None)
async with self.connection() as conn:
async with conn.cursor() as cursor:
# Get last entry in version table, compare against desired version
await cursor.execute("SELECT * FROM VersionHistory ORDER BY time DESC LIMIT 1")
row = await cursor.fetchone()
if row:
return Version(row['version'], row['time'], row['author'])
else:
# No versions in the database
return Version(-1, None, None)

View File

@@ -101,12 +101,12 @@ class Query(Generic[QueryResult]):
if self.connector is None:
raise ValueError("Cannot execute query without cursor, connection, or connector.")
else:
conn = await self.connector.get_connection()
async with self.connector.connection() as conn:
async with conn.cursor() as cursor:
data = await self._execute(cursor)
else:
conn = self.conn
async with conn.cursor() as cursor:
data = await self._execute(cursor)
async with self.conn.cursor() as cursor:
data = await self._execute(cursor)
else:
data = await self._execute(cursor)
return data

View File

@@ -1,4 +1,5 @@
from typing import Optional, Union
import asyncio
import discord
from discord.ext import commands as cmds
@@ -182,30 +183,34 @@ class Economy(LionCog):
# We may need to do a mass row create operation.
targetids = set(target.id for target in targets)
if len(targets) > 1:
conn = await ctx.bot.db.get_connection()
async with conn.transaction():
# First fetch the members which currently exist
query = self.bot.core.data.Member.table.select_where(guildid=ctx.guild.id)
query.select('userid').with_no_adapter()
if 2 * len(targets) < len(ctx.guild.members):
# More efficient to fetch the targets explicitly
query.where(userid=list(targetids))
existent_rows = await query
existentids = set(r['userid'] for r in existent_rows)
async def wrapper():
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
# First fetch the members which currently exist
query = self.bot.core.data.Member.table.select_where(guildid=ctx.guild.id)
query.select('userid').with_no_adapter()
if 2 * len(targets) < len(ctx.guild.members):
# More efficient to fetch the targets explicitly
query.where(userid=list(targetids))
existent_rows = await query
existentids = set(r['userid'] for r in existent_rows)
# Then check if any new userids need adding, and if so create them
new_ids = targetids.difference(existentids)
if new_ids:
# We use ON CONFLICT IGNORE here in case the users already exist.
await self.bot.core.data.User.table.insert_many(
('userid',),
*((id,) for id in new_ids)
).on_conflict(ignore=True)
# TODO: Replace 0 here with the starting_coin value
await self.bot.core.data.Member.table.insert_many(
('guildid', 'userid', 'coins'),
*((ctx.guild.id, id, 0) for id in new_ids)
).on_conflict(ignore=True)
# Then check if any new userids need adding, and if so create them
new_ids = targetids.difference(existentids)
if new_ids:
# We use ON CONFLICT IGNORE here in case the users already exist.
await self.bot.core.data.User.table.insert_many(
('userid',),
*((id,) for id in new_ids)
).on_conflict(ignore=True)
# TODO: Replace 0 here with the starting_coin value
await self.bot.core.data.Member.table.insert_many(
('guildid', 'userid', 'coins'),
*((ctx.guild.id, id, 0) for id in new_ids)
).on_conflict(ignore=True)
task = asyncio.create_task(wrapper(), name="wrapped-create-members")
await task
else:
# With only one target, we can take a simpler path, and make better use of local caches.
await self.bot.core.lions.fetch_member(ctx.guild.id, target.id)
@@ -703,31 +708,34 @@ class Economy(LionCog):
# Alternative flow could be waiting until the target user presses accept
await ctx.interaction.response.defer(thinking=True, ephemeral=True)
conn = await self.bot.db.get_connection()
async with conn.transaction():
# We do this in a transaction so that if something goes wrong,
# the coins deduction is rolled back atomicly
balance = ctx.alion.data.coins
if amount > balance:
await ctx.interaction.edit_original_response(
embed=error_embed(
t(_p(
'cmd:send|error:insufficient',
"You do not have enough lioncoins to do this!\n"
"`Current Balance:` {coin_emoji}{balance}"
)).format(
coin_emoji=self.bot.config.emojis.getemoji('coin'),
balance=balance
async def wrapped():
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
# We do this in a transaction so that if something goes wrong,
# the coins deduction is rolled back atomicly
balance = ctx.alion.data.coins
if amount > balance:
await ctx.interaction.edit_original_response(
embed=error_embed(
t(_p(
'cmd:send|error:insufficient',
"You do not have enough lioncoins to do this!\n"
"`Current Balance:` {coin_emoji}{balance}"
)).format(
coin_emoji=self.bot.config.emojis.getemoji('coin'),
balance=balance
)
),
)
),
)
return
return
# Transfer the coins
await ctx.alion.data.update(coins=(Member.coins - amount))
await target_lion.data.update(coins=(Member.coins + amount))
# Transfer the coins
await ctx.alion.data.update(coins=(Member.coins - amount))
await target_lion.data.update(coins=(Member.coins + amount))
# TODO: Audit trail
await asyncio.create_task(wrapped(), name="wrapped-send")
# Message target
embed = discord.Embed(

View File

@@ -1,6 +1,7 @@
from enum import Enum
from psycopg import sql
from meta.logger import log_wrap
from data import Registry, RowModel, RegisterEnum, JOINTYPE, RawExpr
from data.columns import Integer, Bool, Column, Timestamp
from core.data import CoreData
@@ -101,6 +102,7 @@ class EconomyData(Registry, name='economy'):
created_at = Timestamp()
@classmethod
@log_wrap(action='execute_transaction')
async def execute_transaction(
cls,
transaction_type: TransactionType,
@@ -108,25 +110,27 @@ class EconomyData(Registry, name='economy'):
from_account: int, to_account: int, amount: int, bonus: int = 0,
refunds: int = None
):
conn = await cls._connector.get_connection()
async with conn.transaction():
transaction = await cls.create(
transactiontype=transaction_type,
guildid=guildid, actorid=actorid, amount=amount, bonus=bonus,
from_account=from_account, to_account=to_account,
refunds=refunds
)
if from_account is not None:
await CoreData.Member.table.update_where(
guildid=guildid, userid=from_account
).set(coins=SAFECOINS(CoreData.Member.coins - (amount + bonus)))
if to_account is not None:
await CoreData.Member.table.update_where(
guildid=guildid, userid=to_account
).set(coins=SAFECOINS(CoreData.Member.coins + (amount + bonus)))
return transaction
async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction():
transaction = await cls.create(
transactiontype=transaction_type,
guildid=guildid, actorid=actorid, amount=amount, bonus=bonus,
from_account=from_account, to_account=to_account,
refunds=refunds
)
if from_account is not None:
await CoreData.Member.table.update_where(
guildid=guildid, userid=from_account
).set(coins=SAFECOINS(CoreData.Member.coins - (amount + bonus)))
if to_account is not None:
await CoreData.Member.table.update_where(
guildid=guildid, userid=to_account
).set(coins=SAFECOINS(CoreData.Member.coins + (amount + bonus)))
return transaction
@classmethod
@log_wrap(action='execute_transactions')
async def execute_transactions(cls, *transactions):
"""
Execute multiple transactions in one data transaction.
@@ -142,65 +146,68 @@ class EconomyData(Registry, name='economy'):
if not transactions:
return []
conn = await cls._connector.get_connection()
async with conn.transaction():
# Create the transactions
rows = await cls.table.insert_many(
(
'transactiontype',
'guildid', 'actorid',
'from_account', 'to_account',
'amount', 'bonus',
'refunds'
),
*transactions
).with_adapter(cls._make_rows)
async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction():
# Create the transactions
rows = await cls.table.insert_many(
(
'transactiontype',
'guildid', 'actorid',
'from_account', 'to_account',
'amount', 'bonus',
'refunds'
),
*transactions
).with_adapter(cls._make_rows)
# Update the members
transtable = TemporaryTable(
'_guildid', '_userid', '_amount',
types=('BIGINT', 'BIGINT', 'INTEGER')
)
values = transtable.values
for transaction in transactions:
_, guildid, _, from_acc, to_acc, amount, bonus, _ = transaction
coins = amount + bonus
if coins:
if from_acc:
values.append((guildid, from_acc, -1 * coins))
if to_acc:
values.append((guildid, to_acc, coins))
if values:
Member = CoreData.Member
await Member.table.update_where(
guildid=transtable['_guildid'], userid=transtable['_userid']
).set(
coins=SAFECOINS(Member.coins + transtable['_amount'])
).from_expr(transtable)
# Update the members
transtable = TemporaryTable(
'_guildid', '_userid', '_amount',
types=('BIGINT', 'BIGINT', 'INTEGER')
)
values = transtable.values
for transaction in transactions:
_, guildid, _, from_acc, to_acc, amount, bonus, _ = transaction
coins = amount + bonus
if coins:
if from_acc:
values.append((guildid, from_acc, -1 * coins))
if to_acc:
values.append((guildid, to_acc, coins))
if values:
Member = CoreData.Member
await Member.table.update_where(
guildid=transtable['_guildid'], userid=transtable['_userid']
).set(
coins=SAFECOINS(Member.coins + transtable['_amount'])
).from_expr(transtable)
return rows
@classmethod
@log_wrap(action='refund_transactions')
async def refund_transactions(cls, *transactionids, actorid=0):
if not transactionids:
return []
conn = await cls._connector.get_connection()
async with conn.transaction():
# First fetch the transaction rows to refund
data = await cls.table.select_where(transactionid=transactionids)
if data:
# Build the transaction refund data
records = [
(
TransactionType.REFUND,
tr['guildid'], actorid,
tr['to_account'], tr['from_account'],
tr['amount'] + tr['bonus'], 0,
tr['transactionid']
)
for tr in data
]
# Execute refund transactions
return await cls.execute_transactions(*records)
async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction():
# First fetch the transaction rows to refund
data = await cls.table.select_where(transactionid=transactionids)
if data:
# Build the transaction refund data
records = [
(
TransactionType.REFUND,
tr['guildid'], actorid,
tr['to_account'], tr['from_account'],
tr['amount'] + tr['bonus'], 0,
tr['transactionid']
)
for tr in data
]
# Execute refund transactions
return await cls.execute_transactions(*records)
class ShopTransaction(RowModel):
"""
@@ -217,19 +224,21 @@ class EconomyData(Registry, name='economy'):
itemid = Integer()
@classmethod
@log_wrap(action='purchase_transaction')
async def purchase_transaction(
cls,
guildid: int, actorid: int,
userid: int, itemid: int, amount: int
):
conn = await cls._connector.get_connection()
async with conn.transaction():
row = await EconomyData.Transaction.execute_transaction(
TransactionType.SHOP_PURCHASE,
guildid=guildid, actorid=actorid, from_account=userid, to_account=None,
amount=amount
)
return await cls.create(transactionid=row.transactionid, itemid=itemid)
async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction():
row = await EconomyData.Transaction.execute_transaction(
TransactionType.SHOP_PURCHASE,
guildid=guildid, actorid=actorid, from_account=userid, to_account=None,
amount=amount
)
return await cls.create(transactionid=row.transactionid, itemid=itemid)
class TaskTransaction(RowModel):
"""
@@ -263,19 +272,21 @@ class EconomyData(Registry, name='economy'):
return result[0]['recent'] or 0
@classmethod
@log_wrap(action='reward_completed_tasks')
async def reward_completed(cls, userid, guildid, count, amount):
"""
Reward the specified member `amount` coins for completing `count` tasks.
"""
# TODO: Bonus logic, perhaps apply_bonus(amount), or put this method in the economy cog?
conn = await cls._connector.get_connection()
async with conn.transaction():
row = await EconomyData.Transaction.execute_transaction(
TransactionType.TASKS,
guildid=guildid, actorid=userid, from_account=None, to_account=userid,
amount=amount
)
return await cls.create(transactionid=row.transactionid, count=count)
async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction():
row = await EconomyData.Transaction.execute_transaction(
TransactionType.TASKS,
guildid=guildid, actorid=userid, from_account=None, to_account=userid,
amount=amount
)
return await cls.create(transactionid=row.transactionid, count=count)
class SessionTransaction(RowModel):
"""

View File

@@ -193,18 +193,19 @@ class MemberAdminCog(LionCog):
await lion.data.update(last_left=utc_now())
# Save member roles
conn = await self.bot.db.get_connection()
async with conn.transaction():
await self.data.past_roles.delete_where(
guildid=member.guild.id,
userid=member.id
)
# Insert current member roles
if member.roles:
await self.data.past_roles.insert_many(
('guildid', 'userid', 'roleid'),
*((member.guild.id, member.id, role.id) for role in member.roles)
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
await self.data.past_roles.delete_where(
guildid=member.guild.id,
userid=member.id
)
# Insert current member roles
if member.roles:
await self.data.past_roles.insert_many(
('guildid', 'userid', 'roleid'),
*((member.guild.id, member.id, role.id) for role in member.roles)
)
logger.debug(
f"Stored persisting roles for member <uid:{member.id}> in <gid:{member.guild.id}>."
)

View File

@@ -190,6 +190,8 @@ class ModerationCog(LionCog):
update_args[instance._column] = instance.data
ack_lines.append(instance.update_message)
await ctx.lguild.data.update(**update_args)
# Do the ack
tick = self.bot.config.emojis.tick
embed = discord.Embed(

View File

@@ -483,65 +483,63 @@ class RoleMenu:
)).format(role=role.name)
)
conn = await self.bot.db.get_connection()
async with conn.transaction():
# Remove the role
try:
await member.remove_roles(role)
except discord.Forbidden:
raise UserInputError(
t(_p(
'rolemenu|deselect|error:perms',
"I don't have enough permissions to remove this role from you!"
))
)
except discord.HTTPException:
raise UserInputError(
t(_p(
'rolemenu|deselect|error:discord',
"An unknown error occurred removing your role! Please try again later."
))
)
# Update history
now = utc_now()
history = await self.cog.data.RoleMenuHistory.table.update_where(
menuid=self.data.menuid,
roleid=role.id,
userid=member.id,
removed_at=None,
).set(removed_at=now)
await self.cog.cancel_expiring_tasks(*(row['equipid'] for row in history))
# Refund if required
transactionids = [row['transactionid'] for row in history]
if self.config.refunds.value and any(transactionids):
transactionids = [tid for tid in transactionids if tid]
economy: Economy = self.bot.get_cog('Economy')
refunded = await economy.data.Transaction.refund_transactions(*transactionids)
total_refund = sum(row.amount + row.bonus for row in refunded)
else:
total_refund = 0
# Ack the removal
embed = discord.Embed(
colour=discord.Colour.brand_green(),
title=t(_p(
'rolemenu|deslect|success|title',
"Role removed"
# Remove the role
try:
await member.remove_roles(role)
except discord.Forbidden:
raise UserInputError(
t(_p(
'rolemenu|deselect|error:perms',
"I don't have enough permissions to remove this role from you!"
))
)
if total_refund:
embed.description = t(_p(
'rolemenu|deselect|success:refund|desc',
"You have removed **{role}**, and been refunded {coin} **{amount}**."
)).format(role=role.name, coin=self.bot.config.emojis.coin, amount=total_refund)
else:
embed.description = t(_p(
'rolemenu|deselect|success:norefund|desc',
"You have unequipped **{role}**."
)).format(role=role.name)
return embed
except discord.HTTPException:
raise UserInputError(
t(_p(
'rolemenu|deselect|error:discord',
"An unknown error occurred removing your role! Please try again later."
))
)
# Update history
now = utc_now()
history = await self.cog.data.RoleMenuHistory.table.update_where(
menuid=self.data.menuid,
roleid=role.id,
userid=member.id,
removed_at=None,
).set(removed_at=now)
await self.cog.cancel_expiring_tasks(*(row['equipid'] for row in history))
# Refund if required
transactionids = [row['transactionid'] for row in history]
if self.config.refunds.value and any(transactionids):
transactionids = [tid for tid in transactionids if tid]
economy: Economy = self.bot.get_cog('Economy')
refunded = await economy.data.Transaction.refund_transactions(*transactionids)
total_refund = sum(row.amount + row.bonus for row in refunded)
else:
total_refund = 0
# Ack the removal
embed = discord.Embed(
colour=discord.Colour.brand_green(),
title=t(_p(
'rolemenu|deslect|success|title',
"Role removed"
))
)
if total_refund:
embed.description = t(_p(
'rolemenu|deselect|success:refund|desc',
"You have removed **{role}**, and been refunded {coin} **{amount}**."
)).format(role=role.name, coin=self.bot.config.emojis.coin, amount=total_refund)
else:
embed.description = t(_p(
'rolemenu|deselect|success:norefund|desc',
"You have unequipped **{role}**."
)).format(role=role.name)
return embed
else:
# Member does not have the role, selection case.
required = self.config.required_role.data
@@ -591,57 +589,55 @@ class RoleMenu:
)
)
conn = await self.bot.db.get_connection()
async with conn.transaction():
try:
await member.add_roles(role)
except discord.Forbidden:
raise UserInputError(
t(_p(
'rolemenu|select|error:perms',
"I don't have enough permissions to give you this role!"
))
)
except discord.HTTPException:
raise UserInputError(
t(_p(
'rolemenu|select|error:discord',
"An unknown error occurred while assigning your role! "
"Please try again later."
))
)
now = utc_now()
# Create transaction if applicable
if price:
economy: Economy = self.bot.get_cog('Economy')
tx = await economy.data.Transaction.execute_transaction(
transaction_type=TransactionType.OTHER,
guildid=guild.id, actorid=member.id,
from_account=member.id, to_account=None,
amount=price
)
tid = tx.transactionid
else:
tid = None
# Calculate expiry
duration = mrole.config.duration.value
if duration is not None:
expiry = now + dt.timedelta(seconds=duration)
else:
expiry = None
# Add to equip history
equip = await self.cog.data.RoleMenuHistory.create(
menuid=self.data.menuid, roleid=role.id,
userid=member.id,
obtained_at=now,
transactionid=tid,
expires_at=expiry
try:
await member.add_roles(role)
except discord.Forbidden:
raise UserInputError(
t(_p(
'rolemenu|select|error:perms',
"I don't have enough permissions to give you this role!"
))
)
await self.cog.schedule_expiring(equip)
except discord.HTTPException:
raise UserInputError(
t(_p(
'rolemenu|select|error:discord',
"An unknown error occurred while assigning your role! "
"Please try again later."
))
)
now = utc_now()
# Create transaction if applicable
if price:
economy: Economy = self.bot.get_cog('Economy')
tx = await economy.data.Transaction.execute_transaction(
transaction_type=TransactionType.OTHER,
guildid=guild.id, actorid=member.id,
from_account=member.id, to_account=None,
amount=price
)
tid = tx.transactionid
else:
tid = None
# Calculate expiry
duration = mrole.config.duration.value
if duration is not None:
expiry = now + dt.timedelta(seconds=duration)
else:
expiry = None
# Add to equip history
equip = await self.cog.data.RoleMenuHistory.create(
menuid=self.data.menuid, roleid=role.id,
userid=member.id,
obtained_at=now,
transactionid=tid,
expires_at=expiry
)
await self.cog.schedule_expiring(equip)
# Ack the selection
embed = discord.Embed(

View File

@@ -259,17 +259,6 @@ class RoomCog(LionCog):
lguild,
[member.id for member in members]
)
self._start(room)
# Send tips message
# TODO: Actual tips.
await channel.send(
"{mention} welcome to your private room! You may use the menu below to configure it.".format(mention=owner.mention)
)
# Send config UI
ui = RoomUI(self.bot, room, callerid=owner.id, timeout=None)
await ui.send(channel)
except Exception:
try:
await channel.delete(reason="Failed to created private room")
@@ -454,71 +443,9 @@ class RoomCog(LionCog):
return
# Positive response. Start a transaction.
conn = await self.bot.db.get_connection()
async with conn.transaction():
# Check member balance is sufficient
await ctx.alion.data.refresh()
member_balance = ctx.alion.data.coins
if member_balance < required:
await ctx.reply(
embed=error_embed(
t(_np(
'cmd:room_rent|error:insufficient_funds',
"Renting a private room for `one` day costs {coin}**{required}**, "
"but you only have {coin}**{balance}**!",
"Renting a private room for `{days}` days costs {coin}**{required}**, "
"but you only have {coin}**{balance}**!",
days
)).format(
coin=self.bot.config.emojis.coin,
balance=member_balance,
required=required,
days=days
),
ephemeral=True
)
)
return
# Deduct balance
# TODO: Economy transaction instead of manual deduction
await ctx.alion.data.update(coins=CoreData.Member.coins - required)
# Create room with given starting balance and other parameters
try:
room = await self.create_private_room(
ctx.guild,
ctx.author,
required - rent,
name or ctx.author.display_name,
members=provided
)
except discord.Forbidden:
await ctx.reply(
embed=error_embed(
t(_p(
'cmd:room_rent|error:my_permissions',
"Could not create your private room! You were not charged.\n"
"I have insufficient permissions to create a private room channel."
)),
)
)
await ctx.alion.data.update(coins=CoreData.Member.coins + required)
return
except discord.HTTPException as e:
await ctx.reply(
embed=error_embed(
t(_p(
'cmd:room_rent|error:unknown',
"Could not create your private room! You were not charged.\n"
"An unknown error occurred while creating your private room.\n"
"`{error}`"
)).format(error=e.text),
)
)
await ctx.alion.data.update(coins=CoreData.Member.coins + required)
return
room = await self._do_create_room(ctx, required, days, rent, name, provided)
if room:
# Ack with confirmation message pointing to the room
msg = t(_p(
'cmd:room_rent|success',
@@ -531,6 +458,90 @@ class RoomCog(LionCog):
description=msg
)
)
self._start(room)
# Send tips message
# TODO: Actual tips.
await room.channel.send(
"{mention} welcome to your private room! You may use the menu below to configure it.".format(
mention=ctx.author.mention
)
)
# Send config UI
ui = RoomUI(self.bot, room, callerid=ctx.author.id, timeout=None)
await ui.send(room.channel)
@log_wrap(action='create_room')
async def _do_create_room(self, ctx, required, days, rent, name, provided) -> Room:
t = self.bot.translator.t
# TODO: Rollback the channel create if this fails
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
# Note that the room creation will go into the UI as well.
async with conn.transaction():
# Check member balance is sufficient
await ctx.alion.data.refresh()
member_balance = ctx.alion.data.coins
if member_balance < required:
await ctx.reply(
embed=error_embed(
t(_np(
'cmd:room_rent|error:insufficient_funds',
"Renting a private room for `one` day costs {coin}**{required}**, "
"but you only have {coin}**{balance}**!",
"Renting a private room for `{days}` days costs {coin}**{required}**, "
"but you only have {coin}**{balance}**!",
days
)).format(
coin=self.bot.config.emojis.coin,
balance=member_balance,
required=required,
days=days
),
ephemeral=True
)
)
return
# Deduct balance
# TODO: Economy transaction instead of manual deduction
await ctx.alion.data.update(coins=CoreData.Member.coins - required)
# Create room with given starting balance and other parameters
try:
return await self.create_private_room(
ctx.guild,
ctx.author,
required - rent,
name or ctx.author.display_name,
members=provided
)
except discord.Forbidden:
await ctx.reply(
embed=error_embed(
t(_p(
'cmd:room_rent|error:my_permissions',
"Could not create your private room! You were not charged.\n"
"I have insufficient permissions to create a private room channel."
)),
)
)
await ctx.alion.data.update(coins=CoreData.Member.coins + required)
return
except discord.HTTPException as e:
await ctx.reply(
embed=error_embed(
t(_p(
'cmd:room_rent|error:unknown',
"Could not create your private room! You were not charged.\n"
"An unknown error occurred while creating your private room.\n"
"`{error}`"
)).format(error=e.text),
)
)
await ctx.alion.data.update(coins=CoreData.Member.coins + required)
return
@room_group.command(
name=_p('cmd:room_status', "status"),
@@ -864,43 +875,41 @@ class RoomCog(LionCog):
return
# Start Transaction
conn = await self.bot.db.get_connection()
async with conn.transaction():
await ctx.alion.data.refresh()
member_balance = ctx.alion.data.coins
if member_balance < coins:
await ctx.reply(
embed=error_embed(t(_p(
'cmd:room_deposit|error:insufficient_funds',
"You cannot deposit {coin}**{amount}**! You only have {coin}**{balance}**."
)).format(
coin=self.bot.config.emojis.coin,
amount=coins,
balance=member_balance
)),
ephemeral=True
)
return
# TODO: Economy transaction
await ctx.alion.data.refresh()
member_balance = ctx.alion.data.coins
if member_balance < coins:
await ctx.reply(
embed=error_embed(t(_p(
'cmd:room_deposit|error:insufficient_funds',
"You cannot deposit {coin}**{amount}**! You only have {coin}**{balance}**."
)).format(
coin=self.bot.config.emojis.coin,
amount=coins,
balance=member_balance
)),
ephemeral=True
)
return
# Deduct balance
# TODO: Economy transaction
await ctx.alion.data.update(coins=CoreData.Member.coins - coins)
await room.data.update(coin_balance=RoomData.Room.coin_balance + coins)
# Deduct balance
await ctx.alion.data.update(coins=CoreData.Member.coins - coins)
await room.data.update(coin_balance=RoomData.Room.coin_balance + coins)
# Post deposit message
await room.notify_deposit(ctx.author, coins)
# Post deposit message
await room.notify_deposit(ctx.author, coins)
# Ack the deposit
if ctx.channel.id != room.data.channelid:
ack_msg = t(_p(
'cmd:room_depost|success',
"Success! You have contributed {coin}**{amount}** to the private room bank."
)).format(coin=self.bot.config.emojis.coin, amount=coins)
await ctx.reply(
embed=discord.Embed(colour=discord.Colour.brand_green(), description=ack_msg)
)
else:
await ctx.interaction.delete_original_response()
# Ack the deposit
if ctx.channel.id != room.data.channelid:
ack_msg = t(_p(
'cmd:room_depost|success',
"Success! You have contributed {coin}**{amount}** to the private room bank."
)).format(coin=self.bot.config.emojis.coin, amount=coins)
await ctx.reply(
embed=discord.Embed(colour=discord.Colour.brand_green(), description=ack_msg)
)
else:
await ctx.interaction.delete_original_response()
# ----- Guild Configuration -----
@LionCog.placeholder_group

View File

@@ -7,6 +7,7 @@ from discord.ui.select import select, UserSelect
from meta import LionBot, conf
from meta.errors import UserInputError
from meta.logger import log_wrap
from babel.translator import ctx_locale
from utils.lib import utc_now, MessageArgs, error_embed
from utils.ui import MessageUI, input
@@ -115,38 +116,43 @@ class RoomUI(MessageUI):
return
await submit.response.defer(thinking=True, ephemeral=True)
await self._do_deposit(t, press, amount, submit)
# Post deposit message
await self.room.notify_deposit(press.user, amount)
await self.refresh(thinking=submit)
@log_wrap(isolate=True)
async def _do_deposit(self, t, press, amount, submit):
# Start transaction for deposit
conn = await self.bot.db.get_connection()
async with conn.transaction():
# Get the lion balance directly
lion = await self.bot.core.data.Member.fetch(
self.room.data.guildid,
press.user.id,
cached=False
)
balance = lion.coins
if balance < amount:
await submit.edit_original_response(
embed=error_embed(
t(_p(
'ui:room_status|button:deposit|error:insufficient_funds',
"You cannot deposit {coin}**{amount}**! You only have {coin}**{balance}**."
)).format(
coin=self.bot.config.emojis.coin,
amount=amount,
balance=balance
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
# Get the lion balance directly
lion = await self.bot.core.data.Member.fetch(
self.room.data.guildid,
press.user.id,
cached=False
)
balance = lion.coins
if balance < amount:
await submit.edit_original_response(
embed=error_embed(
t(_p(
'ui:room_status|button:deposit|error:insufficient_funds',
"You cannot deposit {coin}**{amount}**! You only have {coin}**{balance}**."
)).format(
coin=self.bot.config.emojis.coin,
amount=amount,
balance=balance
)
)
)
)
return
# TODO: Economy Transaction
await lion.update(coins=CoreData.Member.coins - amount)
await self.room.data.update(coin_balance=RoomData.Room.coin_balance + amount)
# Post deposit message
await self.room.notify_deposit(press.user, amount)
await self.refresh(thinking=submit)
return
# TODO: Economy Transaction
await lion.update(coins=CoreData.Member.coins - amount)
await self.room.data.update(coin_balance=RoomData.Room.coin_balance + amount)
async def desposit_button_refresh(self):
self.desposit_button.label = self.bot.translator.t(_p(

View File

@@ -217,23 +217,24 @@ class ScheduleCog(LionCog):
for bookingid in bookingids:
await self._cancel_booking_active(*bookingid)
conn = await self.bot.db.get_connection()
async with conn.transaction():
# Now delete from data
records = await self.data.ScheduleSessionMember.table.delete_where(
MULTIVALUE_IN(
('slotid', 'guildid', 'userid'),
*bookingids
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
# Now delete from data
records = await self.data.ScheduleSessionMember.table.delete_where(
MULTIVALUE_IN(
('slotid', 'guildid', 'userid'),
*bookingids
)
)
)
# Refund cancelled bookings
if refund:
maybe_tids = (record['book_transactionid'] for record in records)
tids = [tid for tid in maybe_tids if tid is not None]
if tids:
economy = self.bot.get_cog('Economy')
await economy.data.Transaction.refund_transactions(*tids)
# Refund cancelled bookings
if refund:
maybe_tids = (record['book_transactionid'] for record in records)
tids = [tid for tid in maybe_tids if tid is not None]
if tids:
economy = self.bot.get_cog('Economy')
await economy.data.Transaction.refund_transactions(*tids)
finally:
for lock in locks:
lock.release()
@@ -473,7 +474,6 @@ class ScheduleCog(LionCog):
"One or more requested timeslots are already booked!"
))
raise UserInputError(error)
conn = await self.bot.db.get_connection()
# Booking request is now validated. Perform bookings.
# Fetch or create session data
@@ -482,27 +482,27 @@ class ScheduleCog(LionCog):
*((guildid, slotid) for slotid in slotids)
)
async with conn.transaction():
# Create transactions
economy = self.bot.get_cog('Economy')
trans_data = (
TransactionType.SCHEDULE_BOOK,
guildid, userid, userid, 0,
config.get(ScheduleSettings.ScheduleCost.setting_id).value,
0, None
)
transactions = await economy.data.Transaction.execute_transactions(*(trans_data for _ in slotids))
transactionids = [row.transactionid for row in transactions]
# Create transactions
# TODO: wrap in a transaction so the economy transaction gets unwound if it fails
economy = self.bot.get_cog('Economy')
trans_data = (
TransactionType.SCHEDULE_BOOK,
guildid, userid, userid, 0,
config.get(ScheduleSettings.ScheduleCost.setting_id).value,
0, None
)
transactions = await economy.data.Transaction.execute_transactions(*(trans_data for _ in slotids))
transactionids = [row.transactionid for row in transactions]
# Create bookings
now = utc_now()
booking_data = await self.data.ScheduleSessionMember.table.insert_many(
('guildid', 'userid', 'slotid', 'booked_at', 'book_transactionid'),
*(
(guildid, userid, slotid, now, tid)
for slotid, tid in zip(slotids, transactionids)
)
# Create bookings
now = utc_now()
booking_data = await self.data.ScheduleSessionMember.table.insert_many(
('guildid', 'userid', 'slotid', 'booked_at', 'book_transactionid'),
*(
(guildid, userid, slotid, now, tid)
for slotid, tid in zip(slotids, transactionids)
)
)
# Now pass to activated slots
for record in booking_data:

View File

@@ -356,77 +356,76 @@ class TimeSlot:
Does not modify session room channels (responsibility of the next open).
"""
try:
conn = await self.bot.db.get_connection()
async with conn.transaction():
# Calculate rewards
rewards = []
attendance = []
did_not_show = []
for session in sessions:
bonus = session.bonus_reward * session.all_attended
reward = session.attended_reward + bonus
required = session.min_attendence
for member in session.members.values():
guildid = member.guildid
userid = member.userid
attended = (member.total_clock >= required)
if attended:
rewards.append(
(TransactionType.SCHEDULE_REWARD,
guildid, self.bot.user.id,
0, userid,
reward, 0,
None)
)
else:
did_not_show.append((guildid, userid))
attendance.append(
(self.slotid, guildid, userid, attended, member.total_clock)
# TODO: Transaction?
# Calculate rewards
rewards = []
attendance = []
did_not_show = []
for session in sessions:
bonus = session.bonus_reward * session.all_attended
reward = session.attended_reward + bonus
required = session.min_attendence
for member in session.members.values():
guildid = member.guildid
userid = member.userid
attended = (member.total_clock >= required)
if attended:
rewards.append(
(TransactionType.SCHEDULE_REWARD,
guildid, self.bot.user.id,
0, userid,
reward, 0,
None)
)
else:
did_not_show.append((guildid, userid))
# Perform economy transactions
economy: Economy = self.bot.get_cog('Economy')
transactions = await economy.data.Transaction.execute_transactions(*rewards)
reward_ids = {
(t.guildid, t.to_account): t.transactionid
for t in transactions
}
# Update lobby messages
message_tasks = [
asyncio.create_task(session.update_status(save=False))
for session in sessions
if session.lobby_channel is not None
]
await asyncio.gather(*message_tasks)
# Save attendance
if attendance:
att_table = TemporaryTable(
'_sid', '_gid', '_uid', '_att', '_clock', '_reward',
types=('INTEGER', 'BIGINT', 'BIGINT', 'BOOLEAN', 'INTEGER', 'INTEGER')
attendance.append(
(self.slotid, guildid, userid, attended, member.total_clock)
)
att_table.values = [
(sid, gid, uid, att, clock, reward_ids.get((gid, uid), None))
for sid, gid, uid, att, clock in attendance
]
await self.data.ScheduleSessionMember.table.update_where(
slotid=att_table['_sid'],
guildid=att_table['_gid'],
userid=att_table['_uid'],
).set(
attended=att_table['_att'],
clock=att_table['_clock'],
reward_transactionid=att_table['_reward']
).from_expr(att_table)
# Mark guild sessions as closed
if sessions:
await self.data.ScheduleSession.table.update_where(
slotid=self.slotid,
guildid=list(session.guildid for session in sessions)
).set(closed_at=utc_now())
# Perform economy transactions
economy: Economy = self.bot.get_cog('Economy')
transactions = await economy.data.Transaction.execute_transactions(*rewards)
reward_ids = {
(t.guildid, t.to_account): t.transactionid
for t in transactions
}
# Update lobby messages
message_tasks = [
asyncio.create_task(session.update_status(save=False))
for session in sessions
if session.lobby_channel is not None
]
await asyncio.gather(*message_tasks)
# Save attendance
if attendance:
att_table = TemporaryTable(
'_sid', '_gid', '_uid', '_att', '_clock', '_reward',
types=('INTEGER', 'BIGINT', 'BIGINT', 'BOOLEAN', 'INTEGER', 'INTEGER')
)
att_table.values = [
(sid, gid, uid, att, clock, reward_ids.get((gid, uid), None))
for sid, gid, uid, att, clock in attendance
]
await self.data.ScheduleSessionMember.table.update_where(
slotid=att_table['_sid'],
guildid=att_table['_gid'],
userid=att_table['_uid'],
).set(
attended=att_table['_att'],
clock=att_table['_clock'],
reward_transactionid=att_table['_reward']
).from_expr(att_table)
# Mark guild sessions as closed
if sessions:
await self.data.ScheduleSession.table.update_where(
slotid=self.slotid,
guildid=list(session.guildid for session in sessions)
).set(closed_at=utc_now())
if consequences and did_not_show:
# Trigger blacklist and cancel member bookings as needed
@@ -532,36 +531,35 @@ class TimeSlot:
This involves refunding the booking transactions, deleting the booking rows,
and updating any messages that may have been posted.
"""
conn = await self.bot.db.get_connection()
async with conn.transaction():
# Collect booking rows
bookings = [member.data for session in sessions for member in session.members.values()]
# TODO: Transaction
# Collect booking rows
bookings = [member.data for session in sessions for member in session.members.values()]
if bookings:
# Refund booking transactions
economy: Economy = self.bot.get_cog('Economy')
maybe_tids = (r.book_transactionid for r in bookings)
tids = [tid for tid in maybe_tids if tid is not None]
await economy.data.Transaction.refund_transactions(*tids)
if bookings:
# Refund booking transactions
economy: Economy = self.bot.get_cog('Economy')
maybe_tids = (r.book_transactionid for r in bookings)
tids = [tid for tid in maybe_tids if tid is not None]
await economy.data.Transaction.refund_transactions(*tids)
# Delete booking rows
await self.data.ScheduleSessionMember.table.delete_where(
MEMBERS(*((r.guildid, r.userid) for r in bookings)),
slotid=self.slotid,
)
# Trigger message update for existent messages
lobby_tasks = [
asyncio.create_task(session.update_status(save=False, resend=False))
for session in sessions
]
await asyncio.gather(*lobby_tasks)
# Mark sessions as closed
await self.data.ScheduleSession.table.update_where(
# Delete booking rows
await self.data.ScheduleSessionMember.table.delete_where(
MEMBERS(*((r.guildid, r.userid) for r in bookings)),
slotid=self.slotid,
guildid=[session.guildid for session in sessions]
).set(
closed_at=utc_now()
)
# TODO: Logging
# Trigger message update for existent messages
lobby_tasks = [
asyncio.create_task(session.update_status(save=False, resend=False))
for session in sessions
]
await asyncio.gather(*lobby_tasks)
# Mark sessions as closed
await self.data.ScheduleSession.table.update_where(
slotid=self.slotid,
guildid=[session.guildid for session in sessions]
).set(
closed_at=utc_now()
)
# TODO: Logging

View File

@@ -10,6 +10,7 @@ from discord.ui.button import button, Button
from meta import LionCog, LionContext, LionBot
from meta.errors import SafeCancellation
from meta.logger import log_wrap
from utils import ui
from utils.lib import error_embed
from constants import MAX_COINS
@@ -145,6 +146,7 @@ class ColourShop(Shop):
if (owned is None or item.itemid != owned.itemid) and (item.price <= balance)
]
@log_wrap(action='purchase')
async def purchase(self, itemid) -> ColourRoleItem:
"""
Atomically handle a purchase of a ColourRoleItem.
@@ -157,144 +159,145 @@ class ColourShop(Shop):
If the purchase fails for a known reason, raises SafeCancellation, with the error information.
"""
t = self.bot.translator.t
conn = await self.bot.db.get_connection()
async with conn.transaction():
# Retrieve the item to purchase from data
item = await self.data.ShopItemInfo.table.select_one_where(itemid=itemid)
# Ensure the item is purchasable and not deleted
if not item['purchasable'] or item['deleted']:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:not_purchasable',
"This item may not be purchased!"
))
)
# Refresh the customer
await self.customer.refresh()
# Ensure the guild exists in cache
guild = self.bot.get_guild(self.customer.guildid)
if guild is None:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:no_guild',
"Could not retrieve the server from Discord!"
))
)
# Ensure the customer member actually exists
member = await self.customer.lion.fetch_member()
if member is None:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:no_member',
"Could not retrieve the member from Discord."
))
)
# Ensure the purchased role actually exists
role = guild.get_role(item['roleid'])
if role is None:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:no_role',
"This colour role could not be found in the server."
))
)
# Ensure the customer has enough coins for the item
if self.customer.balance < item['price']:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:low_balance',
"This item costs {coin}{amount}!\nYour balance is {coin}{balance}"
)).format(
coin=self.bot.config.emojis.getemoji('coin'),
amount=item['price'],
balance=self.customer.balance
)
)
owned = self.owned()
if owned is not None:
# Ensure the customer does not already own the item
if owned.itemid == item['itemid']:
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
# Retrieve the item to purchase from data
item = await self.data.ShopItemInfo.table.select_one_where(itemid=itemid)
# Ensure the item is purchasable and not deleted
if not item['purchasable'] or item['deleted']:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:owned',
"You already own this item!"
'shop:colour|purchase|error:not_purchasable',
"This item may not be purchased!"
))
)
# Charge the customer for the item
economy_cog: Economy = self.bot.get_cog('Economy')
economy_data = economy_cog.data
transaction = await economy_data.ShopTransaction.purchase_transaction(
guild.id,
member.id,
member.id,
itemid,
item['price']
)
# Refresh the customer
await self.customer.refresh()
# Add the item to the customer's inventory
await self.data.MemberInventory.create(
guildid=guild.id,
userid=member.id,
transactionid=transaction.transactionid,
itemid=itemid
)
# Ensure the guild exists in cache
guild = self.bot.get_guild(self.customer.guildid)
if guild is None:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:no_guild',
"Could not retrieve the server from Discord!"
))
)
# Give the customer the role (do rollback if this fails)
try:
await member.add_roles(
role,
atomic=True,
reason="Purchased colour role"
)
except discord.NotFound:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:failed_no_role',
"This colour role no longer exists!"
))
)
except discord.Forbidden:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:failed_permissions',
"I do not have enough permissions to give you this colour role!"
))
)
except discord.HTTPException:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:failed_unknown',
"An unknown error occurred while giving you this colour role!"
))
)
# Ensure the customer member actually exists
member = await self.customer.lion.fetch_member()
if member is None:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:no_member',
"Could not retrieve the member from Discord."
))
)
# At this point, the purchase has succeeded and the user has obtained the colour role
# Now, remove their previous colour role (if applicable)
# TODO: We should probably add an on_role_delete event to clear defunct colour roles
if owned is not None:
owned_role = owned.role
if owned_role is not None:
try:
await member.remove_roles(
owned_role,
reason="Removing old colour role.",
atomic=True
# Ensure the purchased role actually exists
role = guild.get_role(item['roleid'])
if role is None:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:no_role',
"This colour role could not be found in the server."
))
)
# Ensure the customer has enough coins for the item
if self.customer.balance < item['price']:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:low_balance',
"This item costs {coin}{amount}!\nYour balance is {coin}{balance}"
)).format(
coin=self.bot.config.emojis.getemoji('coin'),
amount=item['price'],
balance=self.customer.balance
)
except discord.HTTPException:
# Possibly Forbidden, or the role doesn't actually exist anymore (cache failure)
pass
await self.data.MemberInventory.table.delete_where(inventoryid=owned.data.inventoryid)
)
# Purchase complete, update the shop and customer
await self.refresh()
return self.owned()
owned = self.owned()
if owned is not None:
# Ensure the customer does not already own the item
if owned.itemid == item['itemid']:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:owned',
"You already own this item!"
))
)
# Charge the customer for the item
economy_cog: Economy = self.bot.get_cog('Economy')
economy_data = economy_cog.data
transaction = await economy_data.ShopTransaction.purchase_transaction(
guild.id,
member.id,
member.id,
itemid,
item['price']
)
# Add the item to the customer's inventory
await self.data.MemberInventory.create(
guildid=guild.id,
userid=member.id,
transactionid=transaction.transactionid,
itemid=itemid
)
# Give the customer the role (do rollback if this fails)
try:
await member.add_roles(
role,
atomic=True,
reason="Purchased colour role"
)
except discord.NotFound:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:failed_no_role',
"This colour role no longer exists!"
))
)
except discord.Forbidden:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:failed_permissions',
"I do not have enough permissions to give you this colour role!"
))
)
except discord.HTTPException:
raise SafeCancellation(
t(_p(
'shop:colour|purchase|error:failed_unknown',
"An unknown error occurred while giving you this colour role!"
))
)
# At this point, the purchase has succeeded and the user has obtained the colour role
# Now, remove their previous colour role (if applicable)
# TODO: We should probably add an on_role_delete event to clear defunct colour roles
if owned is not None:
owned_role = owned.role
if owned_role is not None:
try:
await member.remove_roles(
owned_role,
reason="Removing old colour role.",
atomic=True
)
except discord.HTTPException:
# Possibly Forbidden, or the role doesn't actually exist anymore (cache failure)
pass
await self.data.MemberInventory.table.delete_where(inventoryid=owned.data.inventoryid)
# Purchase complete, update the shop and customer
await self.refresh()
return self.owned()
async def refresh(self):
"""

View File

@@ -4,6 +4,7 @@ from enum import Enum
from itertools import chain
from psycopg import sql
from meta.logger import log_wrap
from data import RowModel, Registry, Table, RegisterEnum
from data.columns import Integer, String, Timestamp, Bool, Column
@@ -80,6 +81,7 @@ class StatsData(Registry):
end_time = Timestamp()
@classmethod
@log_wrap(action='tracked_time_between')
async def tracked_time_between(cls, *points: tuple[int, int, dt.datetime, dt.datetime]):
query = sql.SQL(
"""
@@ -103,25 +105,27 @@ class StatsData(Registry):
for _ in points
)
)
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
query,
chain(*points)
)
return cursor.fetchall()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
chain(*points)
)
return cursor.fetchall()
@classmethod
@log_wrap(action='study_time_between')
async def study_time_between(cls, guildid: int, userid: int, _start, _end) -> int:
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
"SELECT study_time_between(%s, %s, %s, %s)",
(guildid, userid, _start, _end)
)
return (await cursor.fetchone()[0]) or 0
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
"SELECT study_time_between(%s, %s, %s, %s)",
(guildid, userid, _start, _end)
)
return (await cursor.fetchone()[0]) or 0
@classmethod
@log_wrap(action='study_times_between')
async def study_times_between(cls, guildid: int, userid: int, *points) -> list[int]:
if len(points) < 2:
raise ValueError('Not enough block points given!')
@@ -141,25 +145,27 @@ class StatsData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
)
)
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((guildid, userid), *blocks))
)
return [r['stime'] or 0 for r in await cursor.fetchall()]
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((guildid, userid), *blocks))
)
return [r['stime'] or 0 for r in await cursor.fetchall()]
@classmethod
@log_wrap(action='study_time_since')
async def study_time_since(cls, guildid: int, userid: int, _start) -> int:
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
"SELECT study_time_since(%s, %s, %s)",
(guildid, userid, _start)
)
return (await cursor.fetchone()[0]) or 0
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
"SELECT study_time_since(%s, %s, %s)",
(guildid, userid, _start)
)
return (await cursor.fetchone()[0]) or 0
@classmethod
@log_wrap(action='study_times_between')
async def study_times_since(cls, guildid: int, userid: int, *starts) -> int:
if len(starts) < 1:
raise ValueError('No starting points given!')
@@ -178,15 +184,16 @@ class StatsData(Registry):
sql.SQL("({})").format(sql.Placeholder()) for _ in starts
)
)
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((guildid, userid), starts))
)
return [r['stime'] or 0 for r in await cursor.fetchall()]
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((guildid, userid), starts))
)
return [r['stime'] or 0 for r in await cursor.fetchall()]
@classmethod
@log_wrap(action='leaderboard_since')
async def leaderboard_since(cls, guildid: int, since):
"""
Return the voice totals since the given time for each member in the guild.
@@ -226,23 +233,25 @@ class StatsData(Registry):
)
second_query_args = (since, guildid, since, since)
conn = await cls._connector.get_connection()
async with conn.transaction():
async with conn.cursor() as cursor:
await cursor.execute(second_query, second_query_args)
overshoot_rows = await cursor.fetchall()
overshoot = {row['userid']: int(row['diff']) for row in overshoot_rows}
async with cls._connector.connection() as conn:
cls._connector.conn = conn
async with conn.transaction():
async with conn.cursor() as cursor:
await cursor.execute(second_query, second_query_args)
overshoot_rows = await cursor.fetchall()
overshoot = {row['userid']: int(row['diff']) for row in overshoot_rows}
async with conn.cursor() as cursor:
await cursor.execute(first_query, first_query_args)
leaderboard = [
(row['userid'], int(row['total_duration'] - overshoot.get(row['userid'], 0)))
for row in await cursor.fetchall()
]
async with conn.cursor() as cursor:
await cursor.execute(first_query, first_query_args)
leaderboard = [
(row['userid'], int(row['total_duration'] - overshoot.get(row['userid'], 0)))
for row in await cursor.fetchall()
]
leaderboard.sort(key=lambda t: t[1], reverse=True)
return leaderboard
@classmethod
@log_wrap('leaderboard_all')
async def leaderboard_all(cls, guildid: int):
"""
Return the all-time voice totals for the given guild.
@@ -257,13 +266,13 @@ class StatsData(Registry):
"""
)
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(query, (guildid, ))
leaderboard = [
(row['userid'], int(row['total_duration']))
for row in await cursor.fetchall()
]
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(query, (guildid, ))
leaderboard = [
(row['userid'], int(row['total_duration']))
for row in await cursor.fetchall()
]
return leaderboard
class MemberExp(RowModel):
@@ -296,6 +305,7 @@ class StatsData(Registry):
transactionid = Integer()
@classmethod
@log_wrap(action='xp_since')
async def xp_since(cls, guildid: int, userid: int, *starts):
query = sql.SQL(
"""
@@ -320,15 +330,16 @@ class StatsData(Registry):
sql.Placeholder() for _ in starts
)
)
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((guildid, userid), starts))
)
return [r['exp'] or 0 for r in await cursor.fetchall()]
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((guildid, userid), starts))
)
return [r['exp'] or 0 for r in await cursor.fetchall()]
@classmethod
@log_wrap(action='xp_between')
async def xp_between(cls, guildid: int, userid: int, *points):
blocks = zip(points, points[1:])
query = sql.SQL(
@@ -355,15 +366,16 @@ class StatsData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
)
)
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((guildid, userid), *blocks))
)
return [r['period_xp'] or 0 for r in await cursor.fetchall()]
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((guildid, userid), *blocks))
)
return [r['period_xp'] or 0 for r in await cursor.fetchall()]
@classmethod
@log_wrap(action='leaderboard_since')
async def leaderboard_since(cls, guildid: int, since):
"""
Return the XP totals for the given guild since the given time.
@@ -378,16 +390,17 @@ class StatsData(Registry):
"""
)
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(query, (guildid, since))
leaderboard = [
(row['userid'], int(row['total_xp']))
for row in await cursor.fetchall()
]
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(query, (guildid, since))
leaderboard = [
(row['userid'], int(row['total_xp']))
for row in await cursor.fetchall()
]
return leaderboard
@classmethod
@log_wrap(action='leaderboard_all')
async def leaderboard_all(cls, guildid: int):
"""
Return the all-time XP totals for the given guild.
@@ -402,13 +415,13 @@ class StatsData(Registry):
"""
)
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(query, (guildid, ))
leaderboard = [
(row['userid'], int(row['total_xp']))
for row in await cursor.fetchall()
]
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(query, (guildid, ))
leaderboard = [
(row['userid'], int(row['total_xp']))
for row in await cursor.fetchall()
]
return leaderboard
class UserExp(RowModel):
@@ -436,6 +449,7 @@ class StatsData(Registry):
exp_type: Column[ExpType] = Column()
@classmethod
@log_wrap(action='user_xp_since')
async def xp_since(cls, userid: int, *starts):
query = sql.SQL(
"""
@@ -459,15 +473,16 @@ class StatsData(Registry):
sql.Placeholder() for _ in starts
)
)
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((userid,), starts))
)
return [r['exp'] or 0 for r in await cursor.fetchall()]
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((userid,), starts))
)
return [r['exp'] or 0 for r in await cursor.fetchall()]
@classmethod
@log_wrap(action='user_xp_since')
async def xp_between(cls, userid: int, *points):
blocks = zip(points, points[1:])
query = sql.SQL(
@@ -493,13 +508,13 @@ class StatsData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
)
)
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((userid,), *blocks))
)
return [r['period_xp'] or 0 for r in await cursor.fetchall()]
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((userid,), *blocks))
)
return [r['period_xp'] or 0 for r in await cursor.fetchall()]
class ProfileTag(RowModel):
"""
@@ -531,15 +546,17 @@ class StatsData(Registry):
return [tag.tag for tag in tags]
@classmethod
@log_wrap(action='set_profile_tags')
async def set_tags(self, guildid: Optional[int], userid: int, tags: Iterable[str]):
conn = await self._connector.get_connection()
async with conn.transaction():
await self.table.delete_where(guildid=guildid, userid=userid)
if tags:
await self.table.insert_many(
('guildid', 'userid', 'tag'),
*((guildid, userid, tag) for tag in tags)
)
async with self._connector.connection() as conn:
self._connector.conn = conn
async with conn.transaction():
await self.table.delete_where(guildid=guildid, userid=userid)
if tags:
await self.table.insert_many(
('guildid', 'userid', 'tag'),
*((guildid, userid, tag) for tag in tags)
)
class WeeklyGoals(RowModel):
"""

View File

@@ -473,14 +473,14 @@ class WeeklyMonthlyUI(StatsUI):
# Update the tasklist
if len(new_tasks) != len(tasks) or not all(t == new_t for (t, new_t) in zip(tasks, new_tasks)):
modified = True
conn = await self.bot.db.get_connection()
async with conn.transaction():
await tasks_model.table.delete_where(**key)
if new_tasks:
await tasks_model.table.insert_many(
(*key.keys(), 'completed', 'content'),
*((*key.values(), *new_task) for new_task in new_tasks)
)
async with self._connector.connection() as conn:
async with conn.transaction():
await tasks_model.table.delete_where(**key).with_connection(conn)
if new_tasks:
await tasks_model.table.insert_many(
(*key.keys(), 'completed', 'content'),
*((*key.values(), *new_task) for new_task in new_tasks)
).with_connection(conn)
if modified:
# If either goal type was modified, clear the rendered cache and refresh

View File

@@ -8,6 +8,7 @@ from discord import app_commands as appcmds
from discord.app_commands.transformers import AppCommandOptionType as cmdopt
from meta import LionBot, LionCog, LionContext
from meta.logger import log_wrap
from meta.errors import UserInputError
from utils.lib import utc_now, error_embed
from utils.ui import ChoicedEnum, Transformed, AButton
@@ -141,30 +142,32 @@ class TasklistCog(LionCog):
self.crossload_group(self.configure_group, configcog.configure_group)
@LionCog.listener('on_tasks_completed')
@log_wrap(action="reward tasks completed")
async def reward_tasks_completed(self, member: discord.Member, *taskids: int):
conn = await self.bot.db.get_connection()
async with conn.transaction():
tasklist = await Tasklist.fetch(self.bot, self.data, member.id)
tasks = await tasklist.fetch_tasks(*taskids)
unrewarded = [task for task in tasks if not task.rewarded]
if unrewarded:
reward = (await self.settings.task_reward.get(member.guild.id)).value
limit = (await self.settings.task_reward_limit.get(member.guild.id)).value
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
tasklist = await Tasklist.fetch(self.bot, self.data, member.id)
tasks = await tasklist.fetch_tasks(*taskids)
unrewarded = [task for task in tasks if not task.rewarded]
if unrewarded:
reward = (await self.settings.task_reward.get(member.guild.id)).value
limit = (await self.settings.task_reward_limit.get(member.guild.id)).value
ecog = self.bot.get_cog('Economy')
recent = await ecog.data.TaskTransaction.count_recent_for(member.id, member.guild.id) or 0
max_to_reward = limit - recent
if max_to_reward > 0:
to_reward = unrewarded[:max_to_reward]
ecog = self.bot.get_cog('Economy')
recent = await ecog.data.TaskTransaction.count_recent_for(member.id, member.guild.id) or 0
max_to_reward = limit - recent
if max_to_reward > 0:
to_reward = unrewarded[:max_to_reward]
count = len(to_reward)
amount = count * reward
await ecog.data.TaskTransaction.reward_completed(member.id, member.guild.id, count, amount)
await tasklist.update_tasks(*(task.taskid for task in to_reward), rewarded=True)
logger.debug(
f"Rewarded <uid: {member.id}> in <gid: {member.guild.id}> "
f"'{amount}' coins for completing '{count}' tasks."
)
count = len(to_reward)
amount = count * reward
await ecog.data.TaskTransaction.reward_completed(member.id, member.guild.id, count, amount)
await tasklist.update_tasks(*(task.taskid for task in to_reward), rewarded=True)
logger.debug(
f"Rewarded <uid: {member.id}> in <gid: {member.guild.id}> "
f"'{amount}' coins for completing '{count}' tasks."
)
async def is_tasklist_channel(self, channel) -> bool:
if not channel.guild:
@@ -477,43 +480,40 @@ class TasklistCog(LionCog):
# Contents successfully parsed, update the tasklist.
tasklist = await Tasklist.fetch(self.bot, self.data, ctx.author.id)
# Lazily using the editor because it has a good parser
taskinfo = tasklist.parse_tasklist(lines)
conn = await self.bot.db.get_connection()
async with conn.transaction():
now = utc_now()
now = utc_now()
# Delete tasklist if required
if not append:
await tasklist.update_tasklist(deleted_at=now)
# Delete tasklist if required
if not append:
await tasklist.update_tasklist(deleted_at=now)
# Create tasklist
# TODO: Refactor into common method with parse tasklist
created = {}
target_depth = 0
while True:
to_insert = {}
for i, (parent, truedepth, ticked, content) in enumerate(taskinfo):
if truedepth == target_depth:
to_insert[i] = (
tasklist.userid,
content,
created[parent] if parent is not None else None,
now if ticked else None
)
if to_insert:
# Batch insert
tasks = await tasklist.data.Task.table.insert_many(
('userid', 'content', 'parentid', 'completed_at'),
*to_insert.values()
# Create tasklist
# TODO: Refactor into common method with parse tasklist
created = {}
target_depth = 0
while True:
to_insert = {}
for i, (parent, truedepth, ticked, content) in enumerate(taskinfo):
if truedepth == target_depth:
to_insert[i] = (
tasklist.userid,
content,
created[parent] if parent is not None else None,
now if ticked else None
)
for i, task in zip(to_insert.keys(), tasks):
created[i] = task['taskid']
target_depth += 1
else:
# Reached maximum depth
break
if to_insert:
# Batch insert
tasks = await tasklist.data.Task.table.insert_many(
('userid', 'content', 'parentid', 'completed_at'),
*to_insert.values()
)
for i, task in zip(to_insert.keys(), tasks):
created[i] = task['taskid']
target_depth += 1
else:
# Reached maximum depth
break
# Ack modifications
embed = discord.Embed(

View File

@@ -11,6 +11,7 @@ from discord.ui.button import button, Button, ButtonStyle
from discord.ui.text_input import TextInput, TextStyle
from meta import conf
from meta.logger import log_wrap
from meta.errors import UserInputError
from utils.lib import MessageArgs, utc_now
from utils.ui import LeoUI, LeoModal, FastModal, error_handler_for, ModalRetryUI
@@ -143,6 +144,7 @@ class BulkEditor(LeoModal):
except UserInputError as error:
await ModalRetryUI(self, error.msg).respond_to(interaction)
@log_wrap(action="parse editor")
async def parse_editor(self):
# First parse each line
new_lines = self.tasklist_editor.value.splitlines()
@@ -155,27 +157,28 @@ class BulkEditor(LeoModal):
)
# TODO: Incremental/diff editing
conn = await self.bot.db.get_connection()
async with conn.transaction():
now = utc_now()
async with self.bot.db.connection() as conn:
self.bot.db.conn = conn
async with conn.transaction():
now = utc_now()
if same_layout:
# if the layout has not changed, just edit the tasks
for taskid, (oldinfo, newinfo) in zip(self.lines.keys(), zip(old_info, taskinfo)):
args = {}
if oldinfo[2] != newinfo[2]:
args['completed_at'] = now if newinfo[2] else None
if oldinfo[3] != newinfo[3]:
args['content'] = newinfo[3]
if args:
await self.tasklist.update_tasks(taskid, **args)
else:
# Naive implementation clearing entire tasklist
# Clear tasklist
await self.tasklist.update_tasklist(deleted_at=now)
if same_layout:
# if the layout has not changed, just edit the tasks
for taskid, (oldinfo, newinfo) in zip(self.lines.keys(), zip(old_info, taskinfo)):
args = {}
if oldinfo[2] != newinfo[2]:
args['completed_at'] = now if newinfo[2] else None
if oldinfo[3] != newinfo[3]:
args['content'] = newinfo[3]
if args:
await self.tasklist.update_tasks(taskid, **args)
else:
# Naive implementation clearing entire tasklist
# Clear tasklist
await self.tasklist.update_tasklist(deleted_at=now)
# Create tasklist
await self.tasklist.write_taskinfo(taskinfo)
# Create tasklist
await self.tasklist.write_taskinfo(taskinfo)
class UIMode(Enum):

View File

@@ -2,6 +2,7 @@ from typing import Type
import json
from data import RowModel, Table, ORDER
from meta.logger import log_wrap, set_logging_context
class ModelData:
@@ -60,6 +61,7 @@ class ModelData:
It only updates.
"""
# TODO: Better way of getting the key?
# TODO: Transaction
if not isinstance(parent_id, tuple):
parent_id = (parent_id, )
model = cls._model
@@ -83,6 +85,8 @@ class ListData:
This assumes the list is the only data stored in the table,
and removes list entries by deleting rows.
"""
setting_id: str
# Table storing the setting data
_table_interface: Table
@@ -100,10 +104,12 @@ class ListData:
_cache = None # Map[id -> value]
@classmethod
@log_wrap(isolate=True)
async def _reader(cls, parent_id, use_cache=True, **kwargs):
"""
Read in all entries associated to the given id.
"""
set_logging_context(action="Read cls.setting_id")
if cls._cache is not None and parent_id in cls._cache and use_cache:
return cls._cache[parent_id]
@@ -121,53 +127,56 @@ class ListData:
return data
@classmethod
@log_wrap(isolate=True)
async def _writer(cls, id, data, add_only=False, remove_only=False, **kwargs):
"""
Write the provided list to storage.
"""
set_logging_context(action="Write cls.setting_id")
table = cls._table_interface
conn = await table.connector.get_connection()
async with conn.transaction():
# Handle None input as an empty list
if data is None:
data = []
async with table.connector.connection() as conn:
table.connector.conn = conn
async with conn.transaction():
# Handle None input as an empty list
if data is None:
data = []
current = await cls._reader(id, use_cache=False, **kwargs)
if not cls._order_column and (add_only or remove_only):
to_insert = [item for item in data if item not in current] if not remove_only else []
to_remove = data if remove_only else (
[item for item in current if item not in data] if not add_only else []
)
current = await cls._reader(id, use_cache=False, **kwargs)
if not cls._order_column and (add_only or remove_only):
to_insert = [item for item in data if item not in current] if not remove_only else []
to_remove = data if remove_only else (
[item for item in current if item not in data] if not add_only else []
)
# Handle required deletions
if to_remove:
params = {
cls._id_column: id,
cls._data_column: to_remove
}
await table.delete_where(**params)
# Handle required deletions
if to_remove:
params = {
cls._id_column: id,
cls._data_column: to_remove
}
await table.delete_where(**params)
# Handle required insertions
if to_insert:
columns = (cls._id_column, cls._data_column)
values = [(id, value) for value in to_insert]
await table.insert_many(columns, *values)
# Handle required insertions
if to_insert:
columns = (cls._id_column, cls._data_column)
values = [(id, value) for value in to_insert]
await table.insert_many(columns, *values)
if cls._cache is not None:
new_current = [item for item in current + to_insert if item not in to_remove]
cls._cache[id] = new_current
else:
# Remove all and add all to preserve order
delete_params = {cls._id_column: id}
await table.delete_where(**delete_params)
if cls._cache is not None:
new_current = [item for item in current + to_insert if item not in to_remove]
cls._cache[id] = new_current
else:
# Remove all and add all to preserve order
delete_params = {cls._id_column: id}
await table.delete_where(**delete_params)
if data:
columns = (cls._id_column, cls._data_column)
values = [(id, value) for value in data]
await table.insert_many(columns, *values)
if data:
columns = (cls._id_column, cls._data_column)
values = [(id, value) for value in data]
await table.insert_many(columns, *values)
if cls._cache is not None:
cls._cache[id] = data
if cls._cache is not None:
cls._cache[id] = data
class KeyValueData:

View File

@@ -1,7 +1,7 @@
from itertools import chain
from psycopg import sql
from meta.logger import log_wrap
from data import RowModel, Registry, Table
from data.columns import Integer, String, Timestamp, Bool
@@ -72,6 +72,7 @@ class TextTrackerData(Registry):
member_expid = Integer()
@classmethod
@log_wrap(action='end_text_sessions')
async def end_sessions(cls, connector, *session_data):
query = sql.SQL("""
WITH
@@ -92,7 +93,7 @@ class TextTrackerData(Registry):
) SELECT
data._guildid, 0,
NULL, data._userid,
SUM(_coins), 0, 'TEXT_SESSION'
LEAST(SUM(_coins :: BIGINT), 2147483647), 0, 'TEXT_SESSION'
FROM data
WHERE data._coins > 0
GROUP BY (data._guildid, data._userid)
@@ -100,7 +101,7 @@ class TextTrackerData(Registry):
)
, member AS (
UPDATE members
SET coins = coins + data._coins
SET coins = LEAST(coins :: BIGINT + data._coins :: BIGINT, 2147483647)
FROM data
WHERE members.userid = data._userid AND members.guildid = data._guildid
)
@@ -166,15 +167,16 @@ class TextTrackerData(Registry):
# Or ask for a connection from the connection pool
# Transaction may take some time due to index updates
# Alternatively maybe use the "do not expect response mode"
conn = await connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain(*session_data))
)
async with connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain(*session_data))
)
return
@classmethod
@log_wrap(action='user_messages_between')
async def user_messages_between(cls, userid: int, *points):
"""
Compute messages written between the given points.
@@ -203,15 +205,16 @@ class TextTrackerData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
)
)
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((userid,), *blocks))
)
return [r['period_m'] or 0 for r in await cursor.fetchall()]
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((userid,), *blocks))
)
return [r['period_m'] or 0 for r in await cursor.fetchall()]
@classmethod
@log_wrap(action='member_messages_between')
async def member_messages_between(cls, guildid: int, userid: int, *points):
"""
Compute messages written between the given points.
@@ -241,15 +244,16 @@ class TextTrackerData(Registry):
sql.SQL("({}, {})").format(sql.Placeholder(), sql.Placeholder()) for _ in points[1:]
)
)
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((userid, guildid), *blocks))
)
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((userid, guildid), *blocks))
)
return [r['period_m'] or 0 for r in await cursor.fetchall()]
@classmethod
@log_wrap(action='member_messages_since')
async def member_messages_since(cls, guildid: int, userid: int, *points):
"""
Compute messages written between the given points.
@@ -277,12 +281,12 @@ class TextTrackerData(Registry):
sql.SQL("({})").format(sql.Placeholder()) for _ in points
)
)
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((userid, guildid), points))
)
return [r['messages'] or 0 for r in await cursor.fetchall()]
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain((userid, guildid), points))
)
return [r['messages'] or 0 for r in await cursor.fetchall()]
untracked_channels = Table('untracked_text_channels')

View File

@@ -2,6 +2,7 @@ import datetime as dt
from itertools import chain
from psycopg import sql
from meta.logger import log_wrap
from data import RowModel, Registry, Table
from data.columns import Integer, String, Timestamp, Bool
@@ -113,16 +114,18 @@ class VoiceTrackerData(Registry):
hourly_coins = Integer()
@classmethod
@log_wrap(action='close_voice_session')
async def close_study_session_at(cls, guildid: int, userid: int, _at: dt.datetime) -> int:
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
"SELECT close_study_session_at(%s, %s, %s)",
(guildid, userid, _at)
)
member_data = await cursor.fetchone()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
"SELECT close_study_session_at(%s, %s, %s)",
(guildid, userid, _at)
)
return await cursor.fetchone()
@classmethod
@log_wrap(action='close_voice_sessions')
async def close_voice_sessions_at(cls, *arg_tuples):
query = sql.SQL("""
SELECT
@@ -139,28 +142,30 @@ class VoiceTrackerData(Registry):
for _ in arg_tuples
)
)
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain(*arg_tuples))
)
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain(*arg_tuples))
)
@classmethod
@log_wrap(action='update_voice_session')
async def update_voice_session_at(
cls, guildid: int, userid: int, _at: dt.datetime,
stream: bool, video: bool, rate: float
) -> int:
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
"SELECT * FROM update_voice_session(%s, %s, %s, %s, %s, %s)",
(guildid, userid, _at, stream, video, rate)
)
rows = await cursor.fetchall()
return cls._make_rows(*rows)
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
"SELECT * FROM update_voice_session(%s, %s, %s, %s, %s, %s)",
(guildid, userid, _at, stream, video, rate)
)
rows = await cursor.fetchall()
return cls._make_rows(*rows)
@classmethod
@log_wrap(action='update_voice_sessions')
async def update_voice_sessions_at(cls, *arg_tuples):
query = sql.SQL("""
UPDATE
@@ -209,14 +214,14 @@ class VoiceTrackerData(Registry):
for _ in arg_tuples
)
)
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain(*arg_tuples))
)
rows = await cursor.fetchall()
return cls._make_rows(*rows)
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain(*arg_tuples))
)
rows = await cursor.fetchall()
return cls._make_rows(*rows)
class VoiceSessions(RowModel):
"""
@@ -257,17 +262,19 @@ class VoiceTrackerData(Registry):
transactionid = Integer()
@classmethod
@log_wrap(action='study_time_since')
async def study_time_since(cls, guildid: int, userid: int, _start) -> int:
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
"SELECT study_time_since(%s, %s, %s) AS result",
(guildid, userid, _start)
)
result = await cursor.fetchone()
return (result['result'] or 0) if result else 0
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
"SELECT study_time_since(%s, %s, %s) AS result",
(guildid, userid, _start)
)
result = await cursor.fetchone()
return (result['result'] or 0) if result else 0
@classmethod
@log_wrap(action='multiple_voice_tracked_since')
async def multiple_voice_tracked_since(cls, *arg_tuples):
query = sql.SQL("""
SELECT
@@ -286,13 +293,13 @@ class VoiceTrackerData(Registry):
for _ in arg_tuples
)
)
conn = await cls._connector.get_connection()
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain(*arg_tuples))
)
return await cursor.fetchall()
async with cls._connector.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute(
query,
tuple(chain(*arg_tuples))
)
return await cursor.fetchall()
"""
Schema

View File

@@ -161,7 +161,6 @@ class VoiceSession:
self.state.channelid, guildid=self.guildid, deleted=False
)
conn = await self.bot.db.get_connection()
# Insert an ongoing_session with the correct state, set data
state = self.state
self.data = await self.registry.VoiceSessionsOngoing.create(