Add ABCs for records and transactions

This commit is contained in:
2025-12-02 16:17:02 +10:00
parent 527c26ba49
commit af5a827710
6 changed files with 778 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
from .rules import RuleSet, Rule, RuleInterface
from .transaction import Transaction, Amount, TXNFlag, TXNPosting
from .record import Record
from .converter import Converter
# from .beancounter import BeanCounter
from .partial import PartialPosting, PartialTXN

View File

@@ -0,0 +1,108 @@
from enum import Enum
from typing import ClassVar, Literal, Self, Type, Optional
from . import RuleSet, Record, Transaction
from .partial import PartialTXN
"""
Yet to add:
Converter configuration.
Each converter can provide its own configuration class, which is read from the main configuration file when the converter is used.
This provides such things as the fee account, the main expense account, the name to look for...
self.config.fee_account
config_section = 'CBA Converter'
Notes:
Different currencies may come from different expense accounts.
The logic is actually complex enough that the configuration may as well be in code, honestly.
Or not.. we can just have placeholders filled in for the defaults.
"""
# Single-member enum acting as a sentinel type: it gives ``SKIP`` a type
# of its own, which Converter.convert uses in its return annotation.
_SKIPT = Enum('SKIPT', {'SKIP': 'SKIP'})
SKIP = _SKIPT['SKIP']
class ConverterConfig:
"""
Base class for converter configuration.
"""
@classmethod
def from_dict(cls, data: dict) -> Self:
"""
Load configuration for a serialised dictionary.
The dictionary/MappingProxy will usually be e.g. a configuration section.
"""
raise NotImplementedError
def to_dict(self) -> dict:
"""
Dump configuration to a format appropriate for serialisation.
Must be the inverse of from_dict.
"""
raise NotImplementedError
class Converter[RecordT: Record, ConfigT: ConverterConfig]:
"""
ABC for Record -> Transaction conversion interface.
A Converter defines and controls:
- How input is parsed into Records
- How Records are converted to Transactions (in addition to the RuleSet).
"""
converter_name: str
version: str
display_name: str
config_field: str
record_type: ClassVar[Type[RecordT]]
config_type: ClassVar[Type[ConfigT]]
def __init__(self, config: ConfigT, **kwargs):
self.config = config
@classmethod
def qual_name(cls):
return f"{cls.converter_name}_v{cls.version}"
def annotation(self, record: RecordT, partial: PartialTXN) -> Optional[str]:
"""
Optional user-readable note/warning to attach to a mapped record.
Intended to show e.g. missing fields or invalid transactions.
"""
raise NotImplementedError
def convert(self, record: RecordT, ruleset: RuleSet) -> PartialTXN | _SKIPT:
"""
The meat of the conversion process.
Take a raw Record from the data and convert it into a partial transaction,
or the SKIP sentinel value.
The returned partial transaction may not be complete,
thus not convertible to a full Transaction.
This method must be abstract because individual converters define their
own record-logic.
Users may either edit the resulting TXN directly to complete it,
or provide a rule linking the record fields to the transaction fields.
In the latter case, this method will be re-run to build the transaction.
"""
raise NotImplementedError
@classmethod
def ingest_string(cls, data: str) -> list[RecordT]:
"""
Parse a string of raw input into a list of records.
"""
raise NotImplementedError
@classmethod
def ingest_file(cls, path) -> list[RecordT]:
"""
Ingest a target file (from path) into Records.
"""
with open(path, 'r') as f:
return cls.ingest_string(f.read())

269
src/base/partial.py Normal file
View File

@@ -0,0 +1,269 @@
from typing import NamedTuple, Optional
from dataclasses import dataclass, replace
import datetime as dt
from .transaction import ABCPosting, Transaction, TXNPosting, TXNFlag
class UserInputError(Exception):
    """
    Error describing invalid user-entered text.

    The message is available both through the normal exception arguments
    and as the ``message`` attribute.
    """

    def __init__(self, msg: str):
        self.message = msg
        super().__init__(msg)
@dataclass(slots=True, kw_only=True)
class PartialPosting(ABCPosting):
    """
    Partial posting object, potentially without an account name.
    """
    account: Optional[str] = None

    def upgrade(self, default_account: Optional[str] = None) -> TXNPosting:
        """
        Build a full TXNPosting from this posting.

        Uses this posting's own account when set, otherwise
        ``default_account``. Raises ValueError when neither is available.
        """
        resolved = default_account if self.account is None else self.account
        if resolved is None:
            raise ValueError("PartialPosting has no account set, cannot upgrade.")

        # The account name is passed through str.format, so it may carry
        # a ``{currency}`` placeholder filled from the posting amount.
        account_name = resolved.format(currency=self.amount.currency)
        return TXNPosting(
            account=account_name,
            amount=self.amount,
            cost=self.cost,
            total_cost=self.total_cost,
            price=self.price,
            flag=self.flag,
            comment=self.comment,
        )

    @property
    def partial(self):
        """True while no account has been set on this posting."""
        return self.account is None
class TXNField(NamedTuple):
name: str
display_name: str
value: str
matchable: bool
options: tuple[str] | None = None
@dataclass(kw_only=True)
class PartialTXN:
"""
Represents a simplified fixed-form Beancount Transaction,
with strong assumptions on the transaction type.
TODO: REPR
"""
date: dt.date
flag: TXNFlag = TXNFlag.INCOMPLETE
payee: str = ''
narration: str = ''
comment: Optional[str] = None
document: Optional[str] = None
tags: str = ''
links: str = ''
source_posting: PartialPosting
source_fee_asset_posting: Optional[PartialPosting] = None
source_fee_expense_posting: Optional[PartialPosting] = None
target_posting: PartialPosting
target_fee_expense_posting: Optional[PartialPosting] = None
# Exposing set of fields which may be updated (e.g. from rules)
# Map field name -> display name
fields = {
'flag': 'Flag',
'payee': 'Payee',
'narration': 'Narration',
'comment': "Comment",
'document': "Document",
'tags': "Tags",
'links': "Links",
'source_account': "Source Account",
'source_fee_asset_account': "Source Fee Asset Account",
'source_fee_expense_account': "Source Fee Expense Account",
'target_account': "Target Account",
'target_fee_expense_account': "Target Fee Expense Account",
}
posting_fields = {
'source_posting': 'source_account',
'source_fee_asset_posting': 'source_fee_asset_account',
'source_fee_expense_posting': 'source_fee_expense_account',
'target_posting': 'target_account',
'target_fee_expense_posting': 'target_fee_expense_account',
}
@property
def source_account(self):
return self.source_posting.account
@source_account.setter
def source_account(self, value: str):
self.source_posting.account = value
@property
def target_account(self):
return self.target_posting.account
@target_account.setter
def target_account(self, value: str):
self.target_posting.account = value
@property
def source_fee_asset_account(self):
if (posting := self.source_fee_asset_posting) is not None:
return posting.account
@source_fee_asset_account.setter
def source_fee_asset_account(self, value: str):
if (posting := self.source_fee_asset_posting) is not None:
posting.account = value
else:
raise ValueError("This TXN does not have a source fee asset posting to set.")
@property
def source_fee_expense_account(self):
if (posting := self.source_fee_expense_posting) is not None:
return posting.account
@source_fee_expense_account.setter
def source_fee_expense_account(self, value: str):
if (posting := self.source_fee_expense_posting) is not None:
posting.account = value
else:
raise ValueError("This TXN does not have a source fee expense posting to set.")
@property
def target_fee_expense_account(self):
if (posting := self.target_fee_expense_posting) is not None:
return posting.account
@target_fee_expense_account.setter
def target_fee_expense_account(self, value: str):
if (posting := self.target_fee_expense_posting) is not None:
posting.account = value
else:
raise ValueError("This TXN does not have a target fee expense posting to set.")
@property
def postings(self):
postings = {}
for name in self.posting_fields:
posting = getattr(self, name)
if posting is not None:
postings[name] = posting
return postings
@property
def partial(self):
return any(posting.partial for posting in self.postings.values())
def update(self, overwrite=True, **kwargs):
"""
Update TXN from provided field values.
With overwrite=False, only modifes fields that have not been set.
"""
for field in self.fields:
if field in kwargs and (overwrite or not getattr(self, field)):
# Note that this will error if we attempt to set
# the account name for a posting we don't have.
setattr(self, field, kwargs[field])
def upgrade(self, defaults={}) -> Transaction:
"""
Upgrade this PartialTransaction to a full Transaction using the given default fields if needed.
"""
if self.partial and defaults:
with_defaults = self.copy()
# Remove defaults for postings we don't have
we_have = self.postings.keys()
we_dont_have = set(self.posting_fields.keys()).difference(we_have)
with_defaults.update(
overwrite=False,
**{k: v for k, v in defaults.items() if k not in we_dont_have}
)
upgraded = with_defaults.upgrade()
elif self.partial:
raise ValueError("Cannot upgrade partial transaction.")
else:
upgraded = Transaction(
date=self.date,
flag=self.flag,
payee=self.payee,
narration=self.narration,
comments=[self.comment] if self.comment else [],
document=[self.document] if self.document else [],
tags=self.tags.split(),
links=self.links.split(),
postings=[p.upgrade() for p in self.postings.values()]
)
return upgraded
def copy(self):
update = {}
for posting_name in self.posting_fields:
posting = getattr(self, posting_name)
if posting is not None:
update[posting_name] = replace(posting)
return replace(self, **update)
def display_fields(self) -> list[TXNField]:
"""
The fields to display from this partial transaction.
"""
fields = []
postings = self.postings
field_postings = {n: pn for pn, n in self.posting_fields.items()}
for name, display_name in self.fields.items():
if (pname := field_postings.get(name)) and pname not in postings:
# Don't include posting accounts which aren't there
continue
match name:
case 'flag':
value = self.flag.value
case _:
value = getattr(self, name)
value = str(value) if value is not None else ''
fields.append(TXNField(
name=name,
display_name=display_name,
value=value,
matchable=True
))
return fields
def parse_input(self, entries: dict[str, str]):
"""
Parse a map of field name -> user entered strings
into a dictionary which may be used in update()
"""
updater = {}
for name, userstr in entries.items():
userstr = userstr.strip()
# TODO: Each of these cases needs custom validation
match name:
case 'flag':
if userstr == '!':
updater['flag'] = TXNFlag.INCOMPLETE
elif userstr == '*':
updater['flag'] = TXNFlag.COMPLETE
else:
raise UserInputError(
"Transaction flag must be either '*' or '!'"
)
case 'payee' | 'narration' | 'tags' | 'links':
updater[name] = userstr
case 'comment' | 'document':
updater[name] = userstr or None
case 'source_account' | 'target_account':
updater[name] = userstr
case 'source_fee_asset_account':
updater[name] = userstr
case 'source_fee_expense_account':
updater[name] = userstr
case 'target_fee_expense_account':
updater[name] = userstr
case _:
raise ValueError(f"Unknown field {name} passed to TXN parser.")
return updater

96
src/base/record.py Normal file
View File

@@ -0,0 +1,96 @@
from typing import ClassVar, NamedTuple
from dataclasses import dataclass, field
import datetime as dt
from . import Amount
class RecordField(NamedTuple):
    """
    One displayable field of a Record.
    """
    # Attribute name on the record.
    name: str
    # Human-readable label for the field.
    display_name: str
    # Field value rendered as a string.
    value: str
    # Whether the field is among the record's match fields.
    matchable: bool
@dataclass(kw_only=True, frozen=True)
class Record:
    """
    Represents a raw transaction record read by a converter from input.

    This formalises and fixes the data structure to convert,
    so that the rules and converter remain unchanged even if the underlying data format
    (e.g. bank statement csv) changes.

    Specific converters should subclass the Record for any required custom fields
    and context.

    The Record fields will be provided and checked for the conditional fields of Rules.
    """
    date: dt.date
    source_account: str
    target_account: str
    from_source: Amount
    to_target: Amount
    # Variable-length tuples (possibly empty, as per the defaults).
    fees: tuple[Amount, ...] = field(default_factory=tuple)
    raw: str | None = None
    comments: tuple[str, ...] = field(default_factory=tuple)

    # List of record fields to display in UI
    # List of [field name, display name]
    _display_fields: ClassVar[list[tuple[str, str]]]

    # List of record fields which may be included in rules
    _match_fields: ClassVar[list[str]]

    @classmethod
    def sample_record(cls):
        raise NotImplementedError

    def display_fields(self) -> list[RecordField]:
        """
        Build the fields to display from this record.

        Default implementation uses the _display_fields and _match_fields
        class variables.
        May be overridden by subclasses which need more complex
        display logic.
        """
        fields = []
        for name, display in self._display_fields:
            value = getattr(self, name)
            # None values are shown as empty strings
            value = str(value) if value is not None else ''
            matchable = name in self._match_fields
            fields.append(RecordField(name, display, value, matchable))
        return fields

    def match_fields(self) -> dict[str, str]:
        """
        Build the field: value pairs for matching against a Rule.

        Default implementation uses the _match_fields classvar.
        May be overridden by subclasses which provide more match variables,
        e.g. variables which are not attributes.
        """
        field_values = {}
        for field_name in self._match_fields:
            value = getattr(self, field_name)
            # Replace None values by empty strings
            value = str(value) if value is not None else ''
            field_values[field_name] = value
        return field_values

    def display_table(self):
        """
        Tabulate the record for a simple display.

        Returns a list of lines, one per display field, with the display
        names left-aligned to a common width.
        """
        fields = self.display_fields()
        maxlen = max((len(f.display_name) for f in fields), default=0)
        lines = []
        for _, display, value, _ in fields:
            lines.append(f"{display:<{maxlen}}:\t{value}")
        return lines

View File

@@ -0,0 +1,135 @@
from typing import Self
import os
class Rule:
    """
    A single mapping rule.

    ``conditions`` maps record field names to the values they must have;
    ``values`` holds the transaction field values contributed when the
    rule applies.
    """
    __slots__ = ('conditions', 'values')

    def __init__(self, conditions: dict[str, str], values: dict[str, str]):
        self.conditions = conditions
        self.values = values

    def check(self, record: dict[str, str]) -> bool:
        """
        Check whether this rule applies to the given record fields.

        Every condition must match exactly; a field missing from the
        record is compared as None. A rule with no conditions always
        applies.
        """
        for key, expected in self.conditions.items():
            if not (record.get(key) == expected):
                return False
        return True
class RuleInterface:
    """
    ABC for Record -> Transaction rule data interface.

    Implementations provide the storage behind a RuleSet by overriding
    ``load_rules`` and ``save_rules``.
    """

    def __init__(self, converter: str, **kwargs):
        # Extra keyword arguments are accepted and ignored here.
        self.converter = converter

    def load_rules(self) -> list[Rule]:
        """
        Load and return the rules from storage.
        """
        raise NotImplementedError

    def save_rules(self, rules: list[Rule]):
        """
        Save the given rules to storage.
        """
        raise NotImplementedError
class JSONRuleInterface(RuleInterface):
    """
    Serialise rules into and out of a JSON file.

    Schema:
    {
        'rules': [
            {
                'record_fields': {},
                'transaction_fields': {},
            }
        ]
    }
    """

    def __init__(self, converter: str, path: str, **kwargs):
        super().__init__(converter, **kwargs)
        self.path = path

    def load_rules(self) -> list[Rule]:
        """
        Read the rules from the JSON file at ``self.path``.

        If the file does not exist yet, an empty rules file is written
        first, so the result is then an empty list.
        """
        import json

        if not os.path.exists(self.path):
            self.save_rules([])

        with open(self.path, 'r') as f:
            data = json.load(f)

        return [
            Rule(
                conditions=entry['record_fields'],
                values=entry['transaction_fields'],
            )
            for entry in data.get('rules', [])
        ]

    def save_rules(self, rules: list[Rule]):
        """
        Write the given rules to the JSON file at ``self.path``,
        replacing its previous contents.
        """
        import json

        payload = {
            'rules': [
                {
                    'record_fields': rule.conditions,
                    'transaction_fields': rule.values,
                }
                for rule in rules
            ]
        }
        # Serialise before opening, so a serialisation error does not
        # truncate the existing file.
        text = json.dumps(payload, indent=2)
        with open(self.path, 'w') as f:
            f.write(text)
class DummyRuleInterface(RuleInterface):
    """
    Dummy plug for the rule interface.

    Can be used for testing or if the rules are otherwise loaded internally.
    Loading always yields an empty list and saving does nothing.
    """

    def load_rules(self):
        return list()

    def save_rules(self, rules):
        return None
class RuleSet:
    """
    An ordered list of rules together with the interface used to
    load and save them.
    """

    def __init__(self, rules: list[Rule], interface: RuleInterface):
        self.rules = rules
        self.interface = interface

    @classmethod
    def load_from(cls, interface: RuleInterface) -> Self:
        """
        Build a rule set from whatever the interface currently stores.
        """
        return cls(interface.load_rules(), interface)

    def reload_rules(self):
        """Replace the in-memory rules with those loaded from the interface."""
        self.rules = self.interface.load_rules()

    def save_rules(self):
        """Persist the in-memory rules through the interface."""
        self.interface.save_rules(self.rules)

    def apply(self, record_fields: dict[str, str]) -> dict[str, str]:
        """
        Apply the ruleset to the given record.

        Returns a dictionary of partial-transaction fields.
        Transaction fields may be empty if no rules apply.
        When several rules apply, values from later rules replace
        those from earlier ones.
        """
        merged = {}
        # TODO: Come up with some nice field cache/map for efficient rule lookup.
        # Probably build a rule tree
        for rule in self.rules:
            if not rule.check(record_fields):
                continue
            merged.update(rule.values)
        return merged

    def add_rule(self, rule: Rule):
        """
        Add a rule to the rule set.
        """
        self.rules.append(rule)

164
src/base/transaction.py Normal file
View File

@@ -0,0 +1,164 @@
from typing import ClassVar, Optional
import datetime as dt
from datetime import datetime
from enum import Enum
from dataclasses import dataclass
class TXNFlag(Enum):
    """
    Transaction flag; each member's value is the character written
    in the ledger output.
    """
    COMPLETE = "*"
    INCOMPLETE = "!"
@dataclass(slots=True, frozen=True)
class Amount:
"""
Represents a beancount 'amount' with given currency.
"""
value: float
currency: str
def __str__(self):
return f"{self.value} {self.currency}"
def at_price(self, price: 'Amount') -> 'Amount':
return Amount(self.value * price.value, price.currency)
def __add__(self, other):
if isinstance(other, Amount):
if other.currency != self.currency:
raise ValueError("Cannot add Amounts with different currencies.")
return Amount(self.value + other.value, self.currency)
else:
return NotImplemented
def __sub__(self, other):
if isinstance(other, Amount):
if other.currency != self.currency:
raise ValueError("Cannot subtract Amounts with different currencies.")
return Amount(self.value - other.value, self.currency)
else:
return NotImplemented
def __rsub__(self, other):
return self - other
def __hash__(self):
return hash((self.value, self.currency))
def __neg__(self):
return Amount(-self.value, self.currency)
@dataclass(slots=True, kw_only=True)
class ABCPosting:
    """
    Represents the data of a TXNPosting.

    Cost, price and total cost are unsigned; the sign of the posting
    comes from ``amount``.
    """
    amount: Amount
    cost: Optional[Amount] = None
    total_cost: Optional[Amount] = None
    price: Optional[Amount] = None
    flag: Optional[TXNFlag] = None
    comment: Optional[str] = None

    def __post_init__(self):
        """Validation"""
        if self.total_cost is not None and self.price is not None:
            raise ValueError("Posting cannot have both price and total cost")

    def weight(self) -> Amount:
        """
        Return the balancing weight of this posting.

        Implementation of:
        https://beancount.github.io/docs/beancount_language_syntax.html#balancing-rule-the-weight-of-postings

        Precedence is cost, then total cost, then price, then the plain
        amount.
        """
        if self.cost is not None:
            weight = self.amount.at_price(self.cost)
        elif self.total_cost is not None:
            # The total cost is unsigned, so the weight must take the
            # sign of the units for the transaction to balance.
            if self.amount.value < 0:
                weight = -self.total_cost
            else:
                weight = self.total_cost
        elif self.price is not None:
            weight = self.amount.at_price(self.price)
        else:
            weight = self.amount
        return weight
@dataclass(slots=True, kw_only=True)
class TXNPosting(ABCPosting):
    """
    Represents a single row of a Transaction

    [Flag] Account Amount [{Cost}] [@ Price]

    Note: Remember that Cost, Price, and Total Price are unsigned.
    """
    account: str

    def __str__(self):
        # Assemble the row left to right; optional parts are only
        # included when they are set.
        pieces = [self.flag.value] if self.flag else []
        pieces += [self.account, str(self.amount)]

        if self.cost is not None:
            pieces.append(f"{{{self.cost}}}")
        if self.price is not None:
            pieces.append(f"@ {self.price}")
        if self.total_cost is not None:
            pieces.append(f"@@ {self.total_cost}")
        if self.comment:
            pieces.append(f" ; {self.comment}")

        return ' '.join(pieces)
class Transaction:
    """
    Represents a BeanCount Transaction.

    Everything except the date is given by keyword; omitted keywords
    default to an empty value (or the COMPLETE flag).
    """

    # Absolute tolerance used by check(), to absorb float rounding
    # when summing posting weights.
    BALANCE_TOLERANCE: float = 1e-9

    def __init__(self, date, **kwargs):
        self.date: dt.date = date
        self.flag: TXNFlag = kwargs.get('flag', TXNFlag.COMPLETE)
        self.payee: str = kwargs.get('payee', "")
        self.narration: str = kwargs.get('narration', "")
        self.comments: list[str] = kwargs.get('comments', [])
        self.documents: list[str] = kwargs.get('documents', [])
        self.tags: list[str] = kwargs.get('tags', [])
        self.links: list[str] = kwargs.get('links', [])
        self.postings: list[TXNPosting] = kwargs.get('postings', [])

    def check(self) -> bool:
        """
        Check whether this transaction balances.

        The posting weights are totalled separately for each currency;
        the transaction balances when every total is zero (within
        BALANCE_TOLERANCE). A transaction without postings balances.
        """
        # Weights cannot be summed directly: they may be in different
        # currencies, and an Amount cannot be added to the integer 0
        # that sum() starts from.
        totals = {}
        for posting in self.postings:
            weight = posting.weight()
            totals[weight.currency] = totals.get(weight.currency, 0) + weight.value
        return all(
            abs(total) <= self.BALANCE_TOLERANCE for total in totals.values()
        )

    def __str__(self):
        """
        Beancount ledger representation of this transaction.

        A header line, followed by one line each for the comments,
        documents and postings (in that order).
        """
        header = "{date} {flag} {payee} {narration} {tags} {links}".format(
            date=self.date,
            flag=self.flag.value,
            payee=f"\"{self.payee or ''}\"",
            narration=f"\"{self.narration or ''}\"",
            tags=' '.join('#' + tag for tag in self.tags),
            links=' '.join('^' + link for link in self.links),
        ).strip()

        lines = []
        for comment in self.comments:
            lines.append("; " + comment)
        for document in self.documents:
            lines.append("document: " + document)
        for posting in self.postings:
            lines.append(str(posting))

        return '\n'.join((header, *(' ' + line for line in lines)))