From af5a827710c2e0096ef2ae5272006a56d58a5fc1 Mon Sep 17 00:00:00 2001 From: Interitio Date: Tue, 2 Dec 2025 16:17:02 +1000 Subject: [PATCH] Add ABCs for records and transactions --- src/base/__init__.py | 6 + src/base/converter.py | 108 ++++++++++++++++ src/base/partial.py | 269 ++++++++++++++++++++++++++++++++++++++++ src/base/record.py | 96 ++++++++++++++ src/base/rules.py | 135 ++++++++++++++++++++ src/base/transaction.py | 164 ++++++++++++++++++++++++ 6 files changed, 778 insertions(+) create mode 100644 src/base/partial.py create mode 100644 src/base/record.py create mode 100644 src/base/transaction.py diff --git a/src/base/__init__.py b/src/base/__init__.py index e69de29..f2e6b4d 100644 --- a/src/base/__init__.py +++ b/src/base/__init__.py @@ -0,0 +1,6 @@ +from .rules import RuleSet, Rule, RuleInterface +from .transaction import Transaction, Amount, TXNFlag, TXNPosting +from .record import Record +from .converter import Converter +# from .beancounter import BeanCounter +from .partial import PartialPosting, PartialTXN diff --git a/src/base/converter.py b/src/base/converter.py index e69de29..21c5ba6 100644 --- a/src/base/converter.py +++ b/src/base/converter.py @@ -0,0 +1,108 @@ +from enum import Enum +from typing import ClassVar, Literal, Self, Type, Optional +from . import RuleSet, Record, Transaction +from .partial import PartialTXN + +""" +Yet to add: + Converter configuration. + Each converter can provide its own configuration class, which is read from the main configuration file when the converter is used. + This provides such things as the fee account, the main expense account, the name to look for... + self.config.fee_account + config_section = 'CBA Converter' + +Notes: + Different currencies may come from different expense accounts.. + The logic is actually complex enough that the configuration may as well be in code, honestly. + Or not.. we can just have placeholders filled in for the defaults. +""" + +_SKIPT = Enum('SKIPT', (('SKIP', 'SKIP'),)) +SKIP = _SKIPT.SKIP + + +class ConverterConfig: + """ + Base class for converter configuration. + """ + @classmethod + def from_dict(cls, data: dict) -> Self: + """ + Load configuration for a serialised dictionary. + + The dictionary/MappingProxy will usually be e.g. a configuration section. + """ + raise NotImplementedError + + def to_dict(self) -> dict: + """ + Dump configuration to a format appropriate for serialisation. + + Must be the inverse of from_dict. + """ + raise NotImplementedError + + +class Converter[RecordT: Record, ConfigT: ConverterConfig]: + """ + ABC for Record -> Transaction conversion interface. + + A Converter defines and controls: + - How input is parsed into Records + - How Records are converted to Transactions (in addition to the RuleSet). + """ + converter_name: str + version: str + display_name: str + config_field: str + + record_type: ClassVar[Type[RecordT]] + config_type: ClassVar[Type[ConfigT]] + + def __init__(self, config: ConfigT, **kwargs): + self.config = config + + @classmethod + def qual_name(cls): + return f"{cls.converter_name}_v{cls.version}" + + def annotation(self, record: RecordT, partial: PartialTXN) -> Optional[str]: + """ + Optional user-readable note/warning to attach to a mapped record. + + Intended to show e.g. missing fields or invalid transactions. + """ + raise NotImplementedError + + def convert(self, record: RecordT, ruleset: RuleSet) -> PartialTXN | _SKIPT: + """ + The meat of the conversion process. + Take a raw Record from the data and convert it into a partial transaction, + or the SKIP sentinel value. + + The returned partial transaction may not be complete, + thus not convertible to a full Transaction. + + This method must be abstract because individual converters define their + own record-logic. + + Users may either edit the resulting TXN directly to complete it, + or provide a rule linking the record fields to the transaction fields. + In the latter case, this method will be re-run to build the transaction. + """ + raise NotImplementedError + + @classmethod + def ingest_string(cls, data: str) -> list[RecordT]: + """ + Parse a string of raw input into a list of records. + """ + raise NotImplementedError + + @classmethod + def ingest_file(cls, path) -> list[RecordT]: + """ + Ingest a target file (from path) into Records. + """ + with open(path, 'r') as f: + return cls.ingest_string(f.read()) diff --git a/src/base/partial.py b/src/base/partial.py new file mode 100644 index 0000000..f6b3bda --- /dev/null +++ b/src/base/partial.py @@ -0,0 +1,269 @@ +from typing import NamedTuple, Optional +from dataclasses import dataclass, replace +import datetime as dt + +from .transaction import ABCPosting, Transaction, TXNPosting, TXNFlag + + +class UserInputError(Exception): + def __init__(self, msg: str): + super().__init__(msg) + self.message = msg + + +@dataclass(slots=True, kw_only=True) +class PartialPosting(ABCPosting): + """ + Partial posting object, potentially without an account name. + """ + account: Optional[str] = None + + def upgrade(self, default_account: Optional[str] = None) -> TXNPosting: + account = self.account if self.account is not None else default_account + if account is None: + raise ValueError("PartialPosting has no account set, cannot upgrade.") + + return TXNPosting( + account=account.format(currency=self.amount.currency), + amount=self.amount, + cost=self.cost, + total_cost=self.total_cost, + price=self.price, + flag=self.flag, + comment=self.comment + ) + + @property + def partial(self): + return self.account is None + + +class TXNField(NamedTuple): + name: str + display_name: str + value: str + matchable: bool + options: tuple[str] | None = None + + +@dataclass(kw_only=True) +class PartialTXN: + """ + Represents a simplified fixed-form Beancount Transaction, + with strong assumptions on the transaction type. + + TODO: REPR + """ + date: dt.date + flag: TXNFlag = TXNFlag.INCOMPLETE + payee: str = '' + narration: str = '' + comment: Optional[str] = None + document: Optional[str] = None + tags: str = '' + links: str = '' + source_posting: PartialPosting + source_fee_asset_posting: Optional[PartialPosting] = None + source_fee_expense_posting: Optional[PartialPosting] = None + target_posting: PartialPosting + target_fee_expense_posting: Optional[PartialPosting] = None + + # Exposing set of fields which may be updated (e.g. from rules) + # Map field name -> display name + fields = { + 'flag': 'Flag', + 'payee': 'Payee', + 'narration': 'Narration', + 'comment': "Comment", + 'document': "Document", + 'tags': "Tags", + 'links': "Links", + 'source_account': "Source Account", + 'source_fee_asset_account': "Source Fee Asset Account", + 'source_fee_expense_account': "Source Fee Expense Account", + 'target_account': "Target Account", + 'target_fee_expense_account': "Target Fee Expense Account", + } + posting_fields = { + 'source_posting': 'source_account', + 'source_fee_asset_posting': 'source_fee_asset_account', + 'source_fee_expense_posting': 'source_fee_expense_account', + 'target_posting': 'target_account', + 'target_fee_expense_posting': 'target_fee_expense_account', + } + + @property + def source_account(self): + return self.source_posting.account + + @source_account.setter + def source_account(self, value: str): + self.source_posting.account = value + + @property + def target_account(self): + return self.target_posting.account + + @target_account.setter + def target_account(self, value: str): + self.target_posting.account = value + + @property + def source_fee_asset_account(self): + if (posting := self.source_fee_asset_posting) is not None: + return posting.account + + @source_fee_asset_account.setter + def source_fee_asset_account(self, value: str): + if (posting := self.source_fee_asset_posting) is not None: + posting.account = value + else: + raise ValueError("This TXN does not have a source fee asset posting to set.") + + @property + def source_fee_expense_account(self): + if (posting := self.source_fee_expense_posting) is not None: + return posting.account + + @source_fee_expense_account.setter + def source_fee_expense_account(self, value: str): + if (posting := self.source_fee_expense_posting) is not None: + posting.account = value + else: + raise ValueError("This TXN does not have a source fee expense posting to set.") + + @property + def target_fee_expense_account(self): + if (posting := self.target_fee_expense_posting) is not None: + return posting.account + + @target_fee_expense_account.setter + def target_fee_expense_account(self, value: str): + if (posting := self.target_fee_expense_posting) is not None: + posting.account = value + else: + raise ValueError("This TXN does not have a target fee expense posting to set.") + + @property + def postings(self): + postings = {} + for name in self.posting_fields: + posting = getattr(self, name) + if posting is not None: + postings[name] = posting + return postings + + @property + def partial(self): + return any(posting.partial for posting in self.postings.values()) + + def update(self, overwrite=True, **kwargs): + """ + Update TXN from provided field values. + + With overwrite=False, only modifes fields that have not been set. + """ + for field in self.fields: + if field in kwargs and (overwrite or not getattr(self, field)): + # Note that this will error if we attempt to set + # the account name for a posting we don't have. + setattr(self, field, kwargs[field]) + + def upgrade(self, defaults={}) -> Transaction: + """ + Upgrade this PartialTransaction to a full Transaction using the given default fields if needed. + """ + if self.partial and defaults: + with_defaults = self.copy() + # Remove defaults for postings we don't have + we_have = self.postings.keys() + we_dont_have = set(self.posting_fields.keys()).difference(we_have) + with_defaults.update( + overwrite=False, + **{k: v for k, v in defaults.items() if k not in we_dont_have} + ) + upgraded = with_defaults.upgrade() + elif self.partial: + raise ValueError("Cannot upgrade partial transaction.") + else: + upgraded = Transaction( + date=self.date, + flag=self.flag, + payee=self.payee, + narration=self.narration, + comments=[self.comment] if self.comment else [], + document=[self.document] if self.document else [], + tags=self.tags.split(), + links=self.links.split(), + postings=[p.upgrade() for p in self.postings.values()] + ) + return upgraded + + def copy(self): + update = {} + for posting_name in self.posting_fields: + posting = getattr(self, posting_name) + if posting is not None: + update[posting_name] = replace(posting) + return replace(self, **update) + + def display_fields(self) -> list[TXNField]: + """ + The fields to display from this partial transaction. + """ + fields = [] + postings = self.postings + field_postings = {n: pn for pn, n in self.posting_fields.items()} + for name, display_name in self.fields.items(): + if (pname := field_postings.get(name)) and pname not in postings: + # Don't include posting accounts which aren't there + continue + match name: + case 'flag': + value = self.flag.value + case _: + value = getattr(self, name) + value = str(value) if value is not None else '' + fields.append(TXNField( + name=name, + display_name=display_name, + value=value, + matchable=True + )) + + return fields + + def parse_input(self, entries: dict[str, str]): + """ + Parse a map of field name -> user entered strings + into a dictionary which may be used in update() + """ + updater = {} + for name, userstr in entries.items(): + userstr = userstr.strip() + # TODO: Each of these cases needs custom validation + match name: + case 'flag': + if userstr == '!': + updater['flag'] = TXNFlag.INCOMPLETE + elif userstr == '*': + updater['flag'] = TXNFlag.COMPLETE + else: + raise UserInputError( + "Transaction flag must be either '*' or '!'" + ) + case 'payee' | 'narration' | 'tags' | 'links': + updater[name] = userstr + case 'comment' | 'document': + updater[name] = userstr or None + case 'source_account' | 'target_account': + updater[name] = userstr + case 'source_fee_asset_account': + updater[name] = userstr + case 'source_fee_expense_account': + updater[name] = userstr + case 'target_fee_expense_account': + updater[name] = userstr + case _: + raise ValueError(f"Unknown field {name} passed to TXN parser.") + return updater diff --git a/src/base/record.py b/src/base/record.py new file mode 100644 index 0000000..1af8d35 --- /dev/null +++ b/src/base/record.py @@ -0,0 +1,96 @@ +from typing import ClassVar, NamedTuple +from dataclasses import dataclass, field +import datetime as dt + +from . import Amount + + +class RecordField(NamedTuple): + name: str + display_name: str + value: str + matchable: bool + + +@dataclass(kw_only=True, frozen=True) +class Record: + """ + Represents a raw transaction record read by a converter from input. + + This formalises and fixes the data structure to convert, + so that the rules and converter remain unchanged even if the underlying data format + (e.g. bank statement csv) changes. + + Specific converters should subclass the Record for any required custom fields + and context. + + The Record fields will be provided and checked for the conditional fields of Rules. + """ + date: dt.date + source_account: str + target_account: str + + from_source: Amount + to_target: Amount + + fees: tuple[Amount] = field(default_factory=tuple) + raw: str | None = None + comments: tuple[str] = field(default_factory=tuple) + + # List of record fields to display in UI + # List of [field name, display name] + _display_fields: ClassVar[list[tuple[str, str]]] + + # List of record fields which may be included in rules + _match_fields: ClassVar[list[str]] + + @classmethod + def sample_record(cls): + raise NotImplementedError + + def display_fields(self) -> list[RecordField]: + """ + Build the fields to display from this record. + + Default implementation uses the _display_fields and _match_fields + class variables. + May be overidden by subclasses which need more complex + display logic. + """ + fields = [] + for name, display in self._display_fields: + value = getattr(self, name) + value = str(value) if value is not None else '' + matchable = name in self._match_fields + fields.append(RecordField(name, display, value, matchable)) + + return fields + + def match_fields(self) -> dict[str, str]: + """ + Build the field: value pairs for matching against a Rule. + + Default implementation uses the _match_fields classvar. + May be overridden by subclasses which provide more match variables, + e.g. variables which are not attributes. + """ + field_values = {} + for field_name in self._match_fields: + value = getattr(self, field_name) + # Replace None values by empty strings + value = str(value) if value is not None else '' + field_values[field_name] = value + + return field_values + + def display_table(self): + """ + Tabulate the record for a simple display. + """ + fields = self.display_fields() + maxlen = max((len(f.display_name) for f in fields), default=0) + lines = [] + for _, display, value, _ in fields: + lines.append(f"{display:<{maxlen}}:\t{value}") + + return lines diff --git a/src/base/rules.py b/src/base/rules.py index e69de29..c6006a3 100644 --- a/src/base/rules.py +++ b/src/base/rules.py @@ -0,0 +1,135 @@ +from typing import Self +import os + +class Rule: + __slots__ = ('conditions', 'values') + + def __init__(self, conditions: dict[str, str], values: dict[str, str]): + self.conditions = conditions + self.values = values + + def check(self, record: dict[str, str]) -> bool: + """ + Check whether this rule applies to the given record fields. + """ + return all(record.get(key, None) == value for key, value in self.conditions.items()) + + +class RuleInterface: + """ + ABC for Record -> Transaction rule data interface. + """ + + def __init__(self, converter: str, **kwargs): + self.converter = converter + + def load_rules(self) -> list[Rule]: + raise NotImplementedError + + def save_rules(self, rules: list[Rule]): + """ + Save the given rules to storage. + """ + raise NotImplementedError + + +class JSONRuleInterface(RuleInterface): + """ + Serialise rules into and out of a JSON file. + + Schema: + { + 'rules': [ + { + 'record_fields': {}, + 'transaction_fields': {}, + } + ] + } + """ + def __init__(self, converter: str, path: str, **kwargs): + self.path = path + super().__init__(converter, **kwargs) + + def load_rules(self) -> list[Rule]: + import json + + rules = [] + if not os.path.exists(self.path): + self.save_rules([]) + + with open(self.path, 'r') as f: + data = json.load(f) + for rule_data in data.get('rules', []): + rule = Rule( + conditions=rule_data['record_fields'], + values=rule_data['transaction_fields'], + ) + rules.append(rule) + + return rules + + def save_rules(self, rules: list[Rule]): + import json + + rule_data = [] + for rule in rules: + rule_data.append({ + 'record_fields': rule.conditions, + 'transaction_fields': rule.values, + }) + data = json.dumps({'rules': rule_data}, indent=2) + with open(self.path, 'w') as f: + f.write(data) + + +class DummyRuleInterface(RuleInterface): + """ + Dummy plug for the rule interface. + Can be used for testing or if the rules are otherwise loaded internally. + """ + def load_rules(self): + return [] + + def save_rules(self, rules): + return + + +class RuleSet: + def __init__(self, rules: list[Rule], interface: RuleInterface): + self.rules = rules + self.interface = interface + + @classmethod + def load_from(cls, interface: RuleInterface) -> Self: + rules = interface.load_rules() + return cls(rules, interface) + + def reload_rules(self): + self.rules = self.interface.load_rules() + + def save_rules(self): + self.interface.save_rules(self.rules) + + def apply(self, record_fields: dict[str, str]) -> dict[str, str]: + """ + Apply the ruleset to the given record. + Returns a dictionary of partial-transaction fields. + + Transaction fields may be empty if no rules apply. + """ + result = {} + + # TODO: Come up with some nice field cache/map for efficient rule lookup. + # Probably build a rule tree + for rule in self.rules: + if rule.check(record_fields): + result |= rule.values + + return result + + def add_rule(self, rule: Rule): + """ + Add a rule to the rule set. + """ + self.rules.append(rule) diff --git a/src/base/transaction.py b/src/base/transaction.py new file mode 100644 index 0000000..a9d54ac --- /dev/null +++ b/src/base/transaction.py @@ -0,0 +1,164 @@ +from typing import ClassVar, Optional +import datetime as dt +from datetime import datetime +from enum import Enum +from dataclasses import dataclass + + +class TXNFlag(Enum): + COMPLETE = '*' + INCOMPLETE = '!' + +@dataclass(slots=True, frozen=True) +class Amount: + """ + Represents a beancount 'amount' with given currency. + """ + value: float + currency: str + + def __str__(self): + return f"{self.value} {self.currency}" + + def at_price(self, price: 'Amount') -> 'Amount': + return Amount(self.value * price.value, price.currency) + + def __add__(self, other): + if isinstance(other, Amount): + if other.currency != self.currency: + raise ValueError("Cannot add Amounts with different currencies.") + return Amount(self.value + other.value, self.currency) + else: + return NotImplemented + + def __sub__(self, other): + if isinstance(other, Amount): + if other.currency != self.currency: + raise ValueError("Cannot subtract Amounts with different currencies.") + return Amount(self.value - other.value, self.currency) + else: + return NotImplemented + + def __rsub__(self, other): + return self - other + + def __hash__(self): + return hash((self.value, self.currency)) + + def __neg__(self): + return Amount(-self.value, self.currency) + +@dataclass(slots=True, kw_only=True) +class ABCPosting: + """ + Represents the data of a TXNPosting. + """ + amount: Amount + cost: Optional[Amount] = None + total_cost: Optional[Amount] = None + price: Optional[Amount] = None + flag: Optional[TXNFlag] = None + comment: Optional[str] = None + + def __post_init__(self): + """Validation""" + if self.total_cost is not None and self.price is not None: + raise ValueError("Posting cannot have both price and total cost") + + def weight(self) -> Amount: + """ + Return the balancing weight of this posting. + + Implementation of: + https://beancount.github.io/docs/beancount_language_syntax.html#balancing-rule-the-weight-of-postings + """ + if self.cost is not None: + weight = self.amount.at_price(self.cost) + elif self.total_cost is not None: + weight = self.total_cost + elif self.price is not None: + weight = self.amount.at_price(self.price) + else: + weight = self.amount + return weight + + +@dataclass(slots=True, kw_only=True) +class TXNPosting(ABCPosting): + """ + Represents a single row of a Transaction + + [Flag] Account Amount [{Cost}] [@ Price] + + Note: Remember that Cost, Price, and Total Price are unsigned. + """ + account: str + def __str__(self): + parts = [] + if self.flag: + parts.append(self.flag.value) + parts.append(self.account) + parts.append(str(self.amount)) + if self.cost is not None: + parts.append(f"{{{self.cost}}}") + if self.price is not None: + parts.append(f"@ {self.price}") + if self.total_cost is not None: + parts.append(f"@@ {self.total_cost}") + + if self.comment: + parts.append(f" ; {self.comment}") + + return ' '.join(parts) + + +class Transaction: + """ + Represents a BeanCount Transaction. + """ + + def __init__(self, date, **kwargs): + self.date: dt.date = date + self.flag: TXNFlag = kwargs.get('flag', TXNFlag.COMPLETE) + self.payee: str = kwargs.get('payee', "") + self.narration: str = kwargs.get('narration', "") + self.comments: list[str] = kwargs.get('comments', []) + self.documents: list[str] = kwargs.get('documents', []) + self.tags: list[str] = kwargs.get('tags', []) + self.links: list[str] = kwargs.get('links', []) + + self.postings: list[TXNPosting] = kwargs.get('postings', []) + + def check(self) -> bool: + """ + Check whether this transaction balances. + """ + return sum((posting.weight() for posting in self.postings), 0) == 0 + + def __str__(self): + """ + Beancount ledger representation of this transaction. + """ + header = "{date} {flag} {payee} {narration} {tags} {links}".format( + date=self.date, + flag=self.flag.value, + payee=f"\"{self.payee or ''}\"", + narration=f"\"{self.narration or ''}\"", + tags=' '.join('#' + tag for tag in self.tags), + links=' '.join('^' + link for link in self.links), + ).strip() + + lines = [] + for comment in self.comments: + lines.append("; " + comment) + + for document in self.documents: + lines.append("document: " + document) + + for posting in self.postings: + lines.append(str(posting)) + + return '\n'.join((header, *(' ' + line for line in lines))) + + +