rewrite (settings): New types and bugfixes.

This commit is contained in:
2023-05-14 12:25:59 +03:00
parent 87874e1130
commit c5302adf66
3 changed files with 69 additions and 15 deletions

View File

@@ -70,8 +70,7 @@ class ModelData:
) )
# If we didn't update any rows, create a new row # If we didn't update any rows, create a new row
if not rows: if not rows:
await model.table.fetch_or_create(**model._dict_from_id, **{cls._column: data}) await model.fetch_or_create(**model._dict_from_id(parent_id), **{cls._column: data})
...
if cls._cache is not None: if cls._cache is not None:
cls._cache[parent_id] = data cls._cache[parent_id] = data

View File

@@ -4,10 +4,12 @@ from enum import Enum
import pytz import pytz
import discord import discord
import itertools import itertools
import datetime as dt
from discord import ui from discord import ui
from discord.ui.button import button, Button, ButtonStyle from discord.ui.button import button, Button, ButtonStyle
from dateutil.parser import parse, ParserError
from meta.context import context from meta.context import ctx_bot
from meta.errors import UserInputError from meta.errors import UserInputError
from utils.lib import strfdur, parse_duration from utils.lib import strfdur, parse_duration
from babel import ctx_translator from babel import ctx_translator
@@ -139,8 +141,8 @@ class ChannelSetting(Generic[ParentID, CT], InteractiveSetting[ParentID, int, CT
If the channel cannot be found, returns a `discord.Object` instead. If the channel cannot be found, returns a `discord.Object` instead.
""" """
if data is not None: if data is not None:
ctx = context.get() bot = ctx_bot.get()
channel = ctx.bot.get_channel(data) channel = bot.get_channel(data)
if channel is None: if channel is None:
channel = discord.Object(id=data) channel = discord.Object(id=data)
return channel return channel
@@ -158,7 +160,7 @@ class ChannelSetting(Generic[ParentID, CT], InteractiveSetting[ParentID, int, CT
if data: if data:
return "<#{}>".format(data) return "<#{}>".format(data)
else: else:
return None return "Not Set"
@property @property
def input_formatted(self) -> str: def input_formatted(self) -> str:
@@ -229,8 +231,8 @@ class MessageablelSetting(ChannelSetting):
If the channel cannot be found, returns a `discord.PartialMessageable` instead. If the channel cannot be found, returns a `discord.PartialMessageable` instead.
""" """
if data is not None: if data is not None:
ctx = context.get() bot = ctx_bot.get()
channel = ctx.bot.get_channel(data) channel = bot.get_channel(data)
if channel is None: if channel is None:
channel = ctx.bot.get_partial_messageable(data, guild_id=parent_id) channel = ctx.bot.get_partial_messageable(data, guild_id=parent_id)
return channel return channel
@@ -277,8 +279,8 @@ class RoleSetting(InteractiveSetting[ParentID, int, Union[discord.Role, discord.
role = None role = None
guildid = cls._get_guildid(parent_id) guildid = cls._get_guildid(parent_id)
ctx = context.get() bot = ctx_bot.get()
guild = ctx.bot.get_guild(guildid) guild = bot.get_guild(guildid)
if guild is not None: if guild is not None:
role = guild.get_role(data) role = guild.get_role(data)
if role is None: if role is None:
@@ -650,8 +652,8 @@ class GuildIDSetting(InteractiveSetting[ParentID, int, int]):
If the guild is in cache, attach the name as well. If the guild is in cache, attach the name as well.
""" """
if data is not None: if data is not None:
ctx = context.get() bot = ctx_bot.get()
guild = ctx.bot.get_guild(data) guild = bot.get_guild(data)
if guild is not None: if guild is not None:
return f"`{data}` ({guild.name})" return f"`{data}` ({guild.name})"
else: else:
@@ -670,7 +672,6 @@ class TimezoneSetting(InteractiveSetting[ParentID, str, TZT]):
# Maybe list e.g. Europe (Austria - Iceland) and Europe (Ireland - Ukraine) separately # Maybe list e.g. Europe (Austria - Iceland) and Europe (Ireland - Ukraine) separately
# TODO Definitely need autocomplete here # TODO Definitely need autocomplete here
accepts = "A timezone name."
_accepts = ( _accepts = (
"A timezone name from [this list](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) " "A timezone name from [this list](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) "
"(e.g. `Europe/London`)." "(e.g. `Europe/London`)."
@@ -745,6 +746,60 @@ class TimezoneSetting(InteractiveSetting[ParentID, str, TZT]):
return f"`{data}`" return f"`{data}`"
class TimestampSetting(InteractiveSetting[ParentID, str, dt.datetime]):
"""
Typed Setting ABC representing a fixed point in time.
Data is assumed to be a timezone aware datetime object.
Value is the same as data.
Parsing accepts YYYY-MM-DD [HH:MM] [+TZ]
Display uses a discord timestamp.
"""
_accepts = "A timestamp in the form yyyy-mm-dd HH:MM"
@classmethod
def _data_from_value(cls, parent_id: ParentID, value, **kwargs):
return value
@classmethod
def _data_to_value(cls, parent_id: ParentID, data, **kwargs):
return data
@classmethod
async def _parse_string(cls, parent_id: ParentID, string: str, **kwargs):
string = string.strip()
if string.lower() in ('', 'none', '0'):
ts = None
else:
local_tz = await cls._timezone_from_id(parent_id, **kwargs)
default = dt.datetime.now(tz=local_tz).replace(
hour=0, minute=0,
second=0, microsecond=0
)
try:
ts = parse(string, fuzzy=True, default=default)
except ParserError:
# TOLOCALISE:
raise UserInputError("Invalid date string passed")
return ts
@classmethod
def _format_data(cls, parent_id: ParentID, data, **kwargs):
if data is None:
return "Not Set"
else:
return "<t:{}>".format(int(data.timestamp()))
@classmethod
async def _timezone_from_id(cls, parent_id: ParentID, **kwargs):
"""
Extract the parsing timezone from the given parent id.
Should generally be overriden for interactive settings.
"""
return pytz.UTC
ET = TypeVar('ET', bound='Enum') ET = TypeVar('ET', bound='Enum')
@@ -1046,7 +1101,7 @@ class ChannelListSetting(ListSetting, InteractiveSetting):
_setting = ChannelSetting _setting = ChannelSetting
class RoleListSetting(InteractiveSetting, ListSetting): class RoleListSetting(ListSetting, InteractiveSetting):
""" """
List of roles List of roles
""" """

View File

@@ -257,7 +257,7 @@ class InteractiveSetting(BaseSetting[ParentID, SettingData, SettingValue]):
return '\n'.join(( return '\n'.join((
self.display_name, self.display_name,
'=' * len(self.display_name), '=' * len(self.display_name),
self.long_desc, self.desc,
f"\nAccepts: {self.accepts}" f"\nAccepts: {self.accepts}"
)) ))