diff --git a/src/beanify/__init__.py b/src/beanify/__init__.py new file mode 100644 index 0000000..86966d1 --- /dev/null +++ b/src/beanify/__init__.py @@ -0,0 +1 @@ +from .main import run diff --git a/src/beanify/base/converter.py b/src/beanify/base/converter.py index 21c5ba6..583c83f 100644 --- a/src/beanify/base/converter.py +++ b/src/beanify/base/converter.py @@ -17,7 +17,7 @@ Notes: Or not.. we can just have placeholders filled in for the defaults. """ -_SKIPT = Enum('SKIPT', (('SKIP', 'SKIP'),)) +_SKIPT = Enum("SKIPT", (("SKIP", "SKIP"),)) SKIP = _SKIPT.SKIP @@ -25,6 +25,7 @@ class ConverterConfig: """ Base class for converter configuration. """ + @classmethod def from_dict(cls, data: dict) -> Self: """ @@ -51,6 +52,7 @@ class Converter[RecordT: Record, ConfigT: ConverterConfig]: - How input is parsed into Records - How Records are converted to Transactions (in addition to the RuleSet). """ + converter_name: str version: str display_name: str @@ -92,17 +94,15 @@ class Converter[RecordT: Record, ConfigT: ConverterConfig]: """ raise NotImplementedError - @classmethod - def ingest_string(cls, data: str) -> list[RecordT]: + def ingest_string(self, data: str) -> list[RecordT]: """ Parse a string of raw input into a list of records. """ raise NotImplementedError - @classmethod - def ingest_file(cls, path) -> list[RecordT]: + def ingest_file(self, path) -> list[RecordT]: """ Ingest a target file (from path) into Records. """ - with open(path, 'r') as f: - return cls.ingest_string(f.read()) + with open(path, "r") as f: + return self.ingest_string(f.read()) diff --git a/src/beanify/base/record.py b/src/beanify/base/record.py index 1af8d35..d29b1be 100644 --- a/src/beanify/base/record.py +++ b/src/beanify/base/record.py @@ -26,6 +26,7 @@ class Record: The Record fields will be provided and checked for the conditional fields of Rules. """ + date: dt.date source_account: str target_account: str @@ -33,9 +34,9 @@ class Record: from_source: Amount to_target: Amount - fees: tuple[Amount] = field(default_factory=tuple) + fees: tuple[Amount, ...] = field(default_factory=tuple) raw: str | None = None - comments: tuple[str] = field(default_factory=tuple) + comments: tuple[str, ...] = field(default_factory=tuple) # List of record fields to display in UI # List of [field name, display name] @@ -60,7 +61,7 @@ class Record: fields = [] for name, display in self._display_fields: value = getattr(self, name) - value = str(value) if value is not None else '' + value = str(value) if value is not None else "" matchable = name in self._match_fields fields.append(RecordField(name, display, value, matchable)) @@ -78,7 +79,7 @@ class Record: for field_name in self._match_fields: value = getattr(self, field_name) # Replace None values by empty strings - value = str(value) if value is not None else '' + value = str(value) if value is not None else "" field_values[field_name] = value return field_values diff --git a/src/beanify/base/transaction.py b/src/beanify/base/transaction.py index a9d54ac..baa8b0f 100644 --- a/src/beanify/base/transaction.py +++ b/src/beanify/base/transaction.py @@ -6,21 +6,23 @@ from dataclasses import dataclass class TXNFlag(Enum): - COMPLETE = '*' - INCOMPLETE = '!' + COMPLETE = "*" + INCOMPLETE = "!" + @dataclass(slots=True, frozen=True) class Amount: """ Represents a beancount 'amount' with given currency. """ + value: float currency: str def __str__(self): return f"{self.value} {self.currency}" - def at_price(self, price: 'Amount') -> 'Amount': + def at_price(self, price: "Amount") -> "Amount": return Amount(self.value * price.value, price.currency) def __add__(self, other): @@ -39,6 +41,9 @@ class Amount: else: return NotImplemented + def __eq__(self, other): + return (self.value == other.value) and (self.currency == other.currency) + def __rsub__(self, other): return self - other @@ -48,11 +53,16 @@ class Amount: def __neg__(self): return Amount(-self.value, self.currency) + def __abs__(self): + return Amount(abs(self.value), self.currency) + + @dataclass(slots=True, kw_only=True) class ABCPosting: """ Represents the data of a TXNPosting. """ + amount: Amount cost: Optional[Amount] = None total_cost: Optional[Amount] = None @@ -92,7 +102,9 @@ class TXNPosting(ABCPosting): Note: Remember that Cost, Price, and Total Price are unsigned. """ + account: str + def __str__(self): parts = [] if self.flag: @@ -109,7 +121,7 @@ class TXNPosting(ABCPosting): if self.comment: parts.append(f" ; {self.comment}") - return ' '.join(parts) + return " ".join(parts) class Transaction: @@ -119,15 +131,15 @@ class Transaction: def __init__(self, date, **kwargs): self.date: dt.date = date - self.flag: TXNFlag = kwargs.get('flag', TXNFlag.COMPLETE) - self.payee: str = kwargs.get('payee', "") - self.narration: str = kwargs.get('narration', "") - self.comments: list[str] = kwargs.get('comments', []) - self.documents: list[str] = kwargs.get('documents', []) - self.tags: list[str] = kwargs.get('tags', []) - self.links: list[str] = kwargs.get('links', []) + self.flag: TXNFlag = kwargs.get("flag", TXNFlag.COMPLETE) + self.payee: str = kwargs.get("payee", "") + self.narration: str = kwargs.get("narration", "") + self.comments: list[str] = kwargs.get("comments", []) + self.documents: list[str] = kwargs.get("documents", []) + self.tags: list[str] = kwargs.get("tags", []) + self.links: list[str] = kwargs.get("links", []) - self.postings: list[TXNPosting] = kwargs.get('postings', []) + self.postings: list[TXNPosting] = kwargs.get("postings", []) def check(self) -> bool: """ @@ -142,10 +154,10 @@ class Transaction: header = "{date} {flag} {payee} {narration} {tags} {links}".format( date=self.date, flag=self.flag.value, - payee=f"\"{self.payee or ''}\"", - narration=f"\"{self.narration or ''}\"", - tags=' '.join('#' + tag for tag in self.tags), - links=' '.join('^' + link for link in self.links), + payee=f'"{self.payee or ""}"', + narration=f'"{self.narration or ""}"', + tags=" ".join("#" + tag for tag in self.tags), + links=" ".join("^" + link for link in self.links), ).strip() lines = [] @@ -158,7 +170,4 @@ class Transaction: for posting in self.postings: lines.append(str(posting)) - return '\n'.join((header, *(' ' + line for line in lines))) - - - + return "\n".join((header, *(" " + line for line in lines))) diff --git a/src/beanify/converters/__init__.py b/src/beanify/converters/__init__.py index d50875a..f9a9f1e 100644 --- a/src/beanify/converters/__init__.py +++ b/src/beanify/converters/__init__.py @@ -4,13 +4,20 @@ from base.converter import Converter converters_available: list[Type[Converter]] = [] -def converter_factory(name: str | None = None, qual_name: str | None = None) -> Type[Converter]: + +def converter_factory( + name: str | None = None, qual_name: str | None = None +) -> Type[Converter]: if name and not qual_name: - converter = next((c for c in converters_available if c.converter_name == name), None) + converter = next( + (c for c in converters_available if c.converter_name == name), None + ) if converter is None: raise ValueError(f"No converter matching {name=}") elif qual_name and not name: - converter = next((c for c in converters_available if c.qual_name() == qual_name), None) + converter = next( + (c for c in converters_available if c.qual_name() == qual_name), None + ) if converter is None: raise ValueError(f"No converter matching {qual_name=}") else: @@ -18,8 +25,11 @@ def converter_factory(name: str | None = None, qual_name: str | None = None) -> return converter + def available_converter(converter_cls): converters_available.append(converter_cls) return converter_cls + from .wise_converter import * +from .cba_converter import * diff --git a/src/beanify/converters/cba_converter.py b/src/beanify/converters/cba_converter.py index e69de29..2a44080 100644 --- a/src/beanify/converters/cba_converter.py +++ b/src/beanify/converters/cba_converter.py @@ -0,0 +1,198 @@ +import csv +from dataclasses import dataclass +import datetime as dt +from datetime import datetime +import logging +from typing import NamedTuple +from enum import Enum + +from base import Converter, PartialTXN, PartialPosting, Record, Amount +from base.converter import ConverterConfig +from base.rules import Rule, RuleSet +from base.transaction import TXNFlag + +from . import available_converter + +__all__ = [ + "CBARecord", + "CBAConfig", + "CBAConverter", +] + +logger = logging.getLogger(__name__) + + +class CBACSVRow(NamedTuple): + date: str + amount: str + description: str + unknown: str + + +currency_table = {"US DOLLAR": "USD"} + + +class RecordDirection(Enum): + OUT = "OUT" + IN = "IN" + + +@dataclass(kw_only=True, frozen=True) +class CBARecord(Record): + description: str + direction: RecordDirection + + _display_fields = [ + ("date", "Date"), + ("source_account", "Source Account"), + ("target_account", "Target Account"), + ("from_source", "From Source"), + ("to_target", "To Target"), + ("description", "Description"), + ] + + _match_fields = ["source_account", "target_account", "description"] + + @classmethod + def sample_record(cls): + self = cls( + date=dt.date.today(), + source_account="John Doe", + target_account="Jane Austen", + from_source=Amount(314, "USD"), + to_target=Amount(314, "USD"), + fees=(), + raw="Raw Data", + description="Raw Description", + direction=RecordDirection.OUT, + ) + return self + + +@dataclass +class CBAConfig(ConverterConfig): + asset_account: str + asset_currency: str + + required = { + "asset_account", + "asset_currency", + } + + @classmethod + def from_dict(cls, data: dict) -> "CBAConfig": + if (f := next((f for f in cls.required if f not in data), None)) is not None: + raise ValueError( + f"CBA CSV Converter Configuration missing required field: {f}" + ) + return cls(**data) + + +@available_converter +class CBAConverter(Converter[CBARecord, CBAConfig]): + record_type = CBARecord + config_type = CBAConfig + converter_name = "cbacsv" + version = "0" + display_name = "CBACSV converter v0" + config_field = "CBACSV" + + def convert(self, record: CBARecord, ruleset: RuleSet) -> PartialTXN: + fields = {} + + match record.direction: + case RecordDirection.OUT: + fields["source_account"] = self.config.asset_account + case RecordDirection.IN: + fields["target_account"] = self.config.asset_account + + fields |= ruleset.apply(record.match_fields()) + + args = {} + args["date"] = record.date + if "flag" in fields: + args["flag"] = TXNFlag(fields["flag"]) + + for name in {"payee", "narration", "comment", "document", "tags", "links"}: + if name in fields: + args[name] = fields[name] + + args["source_posting"] = PartialPosting( + account=fields.get("source_account", None), + amount=record.from_source, + total_cost=abs(record.to_target) + if abs(record.from_source) != abs(record.to_target) + else None, + ) + + args["target_posting"] = PartialPosting( + account=fields.get("target_account", None), + amount=record.to_target, + ) + args.setdefault("comment", record.description) + + txn = PartialTXN(**args) + logger.debug(f"Converted CBA CSV Record {record!r} to PartialTXN {txn!r}") + return txn + + def _make_record(self, row: CBACSVRow): + dt_format = "%d/%m/%Y" + + created_on = datetime.strptime(row.date, dt_format).date() + asset_amount_val = float(row.amount.strip('"')) + expense = True if asset_amount_val < 0 else False + asset_amount = Amount(asset_amount_val, self.config.asset_currency) + + other_value = asset_amount_val + bean_curr = self.config.asset_currency + desc = row.description.strip('"') + + # Attempt to handle currency conversions + currency = next((curr for curr in currency_table if desc.endswith(curr)), None) + if currency is not None: + bean_curr = currency_table[currency] + desc, rawamount = desc[: -len(currency)].rsplit(maxsplit=1) + try: + other_value = float(rawamount) + except ValueError: + currency = None + + other_value = abs(other_value) * (1 if expense else -1) + other_amount = Amount(abs(other_value) * (1 if expense else -1), bean_curr) + + if expense: + record = CBARecord( + date=created_on, + source_account="ACCOUNT", + target_account=desc.split(" ")[0].strip(), + from_source=asset_amount, + to_target=other_amount, + raw=",".join(row), + description=desc, + direction=RecordDirection.OUT, + ) + else: + record = CBARecord( + date=created_on, + target_account="ACCOUNT", + source_account=desc.split(" ")[0].strip(), + to_target=asset_amount, + from_source=other_amount, + raw=",".join(row), + description=desc, + direction=RecordDirection.IN, + ) + return record + + def ingest_string(self, data: str) -> list[CBARecord]: + reader = csv.reader(data.splitlines()) + + records = [] + for row in reader: + record = self._make_record(CBACSVRow(*row)) + records.append(record) + return records + + def ingest_file(self, path) -> list[CBARecord]: + with open(path) as f: + return self.ingest_string(f.read()) diff --git a/src/beanify/converters/sample.py b/src/beanify/converters/sample.py new file mode 100644 index 0000000..fdd098f --- /dev/null +++ b/src/beanify/converters/sample.py @@ -0,0 +1,89 @@ +import csv +from dataclasses import dataclass +import datetime as dt +from datetime import datetime +import logging + +from base import Converter, PartialTXN, PartialPosting, Record, Amount +from base.converter import ConverterConfig +from base.rules import Rule, RuleSet +from base.transaction import TXNFlag + +from . import available_converter + +__all__ = [ + "SampleRecord", + "SampleConfig", + "SampleCOnverter", +] + +logger = logging.getLogger(__name__) + + +@dataclass(kw_only=True, frozen=True) +class SampleRecord(Record): + # Add any extra record fields needed here + + _display_fields = [ + ("date", "Date"), + ("source_account", "Source Account"), + ("target_account", "Target Account"), + ("from_source", "From Source"), + ("to_target", "To Target"), + ("raw", "Raw Original"), + ] + + _match_fields = ["source_account", "target_account"] + + @classmethod + def sample_record(cls): + self = cls( + date=dt.date.today(), + source_account="John Doe", + target_account="Jane Austen", + from_source=Amount(314, "USD"), + to_target=Amount(314, "USD"), + fees=(), + raw="Raw Data", + ) + return self + + +@dataclass +class SampleConfig(ConverterConfig): + asset_account: str + + required = { + "asset_account", + } + + @classmethod + def from_dict(cls, data: dict) -> "SampleConfig": + if (f := next((f for f in cls.required if f not in data), None)) is not None: + raise ValueError(f"Wise Configuration missing required field: {f}") + return cls(**data) + + +@available_converter +class SampleConverter(Converter[SampleRecord, SampleConfig]): + record_type = SampleRecord + config_type = SampleConfig + converter_name = "sample" + version = "0" + display_name = "Sample converter v0" + config_field = "SAMPLE" + + def __init__(self, config: SampleConfig, **kwargs): + self.config = config + super().__init__(**kwargs) + + def convert(self, record: SampleRecord, ruleset: RuleSet) -> PartialTXN: + raise NotImplementedError + + @classmethod + def ingest_string(cls, data: str) -> list[SampleRecord]: + raise NotImplementedError + + @classmethod + def ingest_file(cls, path) -> list[SampleRecord]: + raise NotImplementedError diff --git a/src/beanify/converters/wise_converter.py b/src/beanify/converters/wise_converter.py index a2f59fd..c77e463 100644 --- a/src/beanify/converters/wise_converter.py +++ b/src/beanify/converters/wise_converter.py @@ -13,11 +13,11 @@ from base.transaction import TXNFlag from . import available_converter __all__ = [ - 'WiseRecordStatus', - 'WiseRecordDirection', - 'WiseRecord', - 'WiseConfig', - 'WiseConverter', + "WiseRecordStatus", + "WiseRecordDirection", + "WiseRecord", + "WiseConfig", + "WiseConverter", ] @@ -25,13 +25,14 @@ logger = logging.getLogger(__name__) class WiseRecordStatus(Enum): - COMPLETED = 'COMPLETED' - CANCELLED = 'CANCELLED' + COMPLETED = "COMPLETED" + CANCELLED = "CANCELLED" + class WiseRecordDirection(Enum): - OUT = 'OUT' - NEUTRAL = 'NEUTRAL' - IN = 'IN' + OUT = "OUT" + NEUTRAL = "NEUTRAL" + IN = "IN" @dataclass(kw_only=True, frozen=True) @@ -46,26 +47,30 @@ class WiseRecord(Record): exchange_rate: float _display_fields = [ - ('id', "ID"), - ('status', "Status"), - ('direction', "Direction"), - ('created_on', "Created"), - ('finished_on', "Finished"), - ('from_source_net', "Source Net Amount"), - ('source_fee', "Source Fee"), - ('exchange_rate', "Exchange Rate"), - ('exchanged_amount', "Total Exchanged"), - ('target_fee', "Target Fee"), - ('source_currency', "Source Currency"), - ('target_currency', "Target Currency"), - ('source_account', "Source Account"), - ('target_account', "Target Account"), + ("id", "ID"), + ("status", "Status"), + ("direction", "Direction"), + ("created_on", "Created"), + ("finished_on", "Finished"), + ("from_source_net", "Source Net Amount"), + ("source_fee", "Source Fee"), + ("exchange_rate", "Exchange Rate"), + ("exchanged_amount", "Total Exchanged"), + ("target_fee", "Target Fee"), + ("source_currency", "Source Currency"), + ("target_currency", "Target Currency"), + ("source_account", "Source Account"), + ("target_account", "Target Account"), ] _match_fields = [ - 'id', 'status', 'direction', - 'source_currency', 'target_currency', - 'source_account', 'target_account', + "id", + "status", + "direction", + "source_currency", + "target_currency", + "source_account", + "target_account", ] @property @@ -99,14 +104,14 @@ class WiseRecord(Record): amount = amount + self.target_fee return amount - @classmethod + @classmethod def sample_record(cls): self = cls( date=dt.date.today(), source_account="John Doe", target_account="Jane Austen", - from_source=Amount(314, 'GBK'), - to_target=Amount(314, 'KBG'), + from_source=Amount(314, "GBK"), + to_target=Amount(314, "KBG"), fees=tuple(), raw="Raw Data", id="00000", @@ -114,72 +119,72 @@ class WiseRecord(Record): direction=WiseRecordDirection.IN, created_on=datetime.now(), finished_on=datetime.now(), - source_fee=Amount(1, 'GBK'), - target_fee=Amount(1, 'KBG'), + source_fee=Amount(1, "GBK"), + target_fee=Amount(1, "KBG"), exchange_rate=1, ) return self @classmethod def from_row(cls, row): - id = row[0] - status = row[1] - direction = row[2] - created_on = row[3] - finished_on = row[3] - source_fee_amount = float(row[5] or 0) - source_fee_currency = row[6] - target_fee_amount = float(row[7] or 0) - target_fee_currency = row[8] - source_name = row[9] - source_amount_final = float(row[10] or 0) - source_currency = row[11] - target_name = row[12] - target_amount_final = float(row[13] or 0) - target_currency = row[14] - exchange_rate = float(row[15] or 0) - - wise_dt_format = '%Y-%m-%d %H:%M:%S' + id = row[0] + status = row[1] + direction = row[2] + created_on = row[3] + finished_on = row[3] + source_fee_amount = float(row[5] or 0) + source_fee_currency = row[6] + target_fee_amount = float(row[7] or 0) + target_fee_currency = row[8] + source_name = row[9] + source_amount_final = float(row[10] or 0) + source_currency = row[11] + target_name = row[12] + target_amount_final = float(row[13] or 0) + target_currency = row[14] + exchange_rate = float(row[15] or 0) - created_on_dt = datetime.strptime(created_on, wise_dt_format) - finished_on_dt = datetime.strptime(finished_on, wise_dt_format) if finished_on else None - - fees = [] + wise_dt_format = "%Y-%m-%d %H:%M:%S" - if source_fee_amount: - source_fee = Amount(source_fee_amount, source_fee_currency) - fees.append(source_fee) - else: - source_fee = None - if target_fee_amount: - target_fee = Amount(target_fee_amount, target_fee_currency) - fees.append(target_fee) - else: - target_fee = None + created_on_dt = datetime.strptime(created_on, wise_dt_format) + finished_on_dt = ( + datetime.strptime(finished_on, wise_dt_format) if finished_on else None + ) - raw = ','.join(row) + fees = [] - self = cls( - date=created_on_dt.date(), - source_account=source_name, - target_account=target_name, - from_source=Amount(source_amount_final, source_currency), - to_target=Amount(target_amount_final, target_currency), - fees=tuple(fees), - raw=raw, - id=id, - status=WiseRecordStatus(status), - direction=WiseRecordDirection(direction), - created_on=created_on_dt, - finished_on=finished_on_dt, - source_fee=source_fee, - target_fee=target_fee, - exchange_rate=float(exchange_rate), - ) - logger.debug( - f"Converted Wise row {raw} to record {self!r}" - ) - return self + if source_fee_amount: + source_fee = Amount(source_fee_amount, source_fee_currency) + fees.append(source_fee) + else: + source_fee = None + if target_fee_amount: + target_fee = Amount(target_fee_amount, target_fee_currency) + fees.append(target_fee) + else: + target_fee = None + + raw = ",".join(row) + + self = cls( + date=created_on_dt.date(), + source_account=source_name, + target_account=target_name, + from_source=Amount(source_amount_final, source_currency), + to_target=Amount(target_amount_final, target_currency), + fees=tuple(fees), + raw=raw, + id=id, + status=WiseRecordStatus(status), + direction=WiseRecordDirection(direction), + created_on=created_on_dt, + finished_on=finished_on_dt, + source_fee=source_fee, + target_fee=target_fee, + exchange_rate=float(exchange_rate), + ) + logger.debug(f"Converted Wise row {raw} to record {self!r}") + return self @dataclass @@ -191,34 +196,35 @@ class WiseConfig(ConverterConfig): TODO: Actual field defaults for fill-in """ + asset_account: str fee_account: str required = { - 'asset_account', - 'fee_account', + "asset_account", + "fee_account", } - @classmethod - def from_dict(cls, data: dict) -> 'WiseConfig': + @classmethod + def from_dict(cls, data: dict) -> "WiseConfig": if (f := next((f for f in cls.required if f not in data), None)) is not None: raise ValueError(f"Wise Configuration missing required field: {f}") return cls(**data) + @available_converter class WiseConverter(Converter[WiseRecord, WiseConfig]): record_type = WiseRecord config_type = WiseConfig - converter_name = 'wise' - version = '0' + converter_name = "wise" + version = "0" display_name = "Wise Record Converter v0" - config_field = 'WISE' + config_field = "WISE" def __init__(self, config: WiseConfig, **kwargs): self.config = config - def annotation(self, record: WiseRecord, partial: PartialTXN): - ... + def annotation(self, record: WiseRecord, partial: PartialTXN): ... def convert(self, record: WiseRecord, ruleset: RuleSet) -> PartialTXN: fields = {} @@ -226,64 +232,77 @@ class WiseConverter(Converter[WiseRecord, WiseConfig]): match record.direction: # Handle configured default accounts case WiseRecordDirection.OUT: - fields['source_account'] = self.config.asset_account.format(currency=record.source_currency) + fields["source_account"] = self.config.asset_account.format( + currency=record.source_currency + ) case WiseRecordDirection.NEUTRAL: - fields['source_account'] = self.config.asset_account.format(currency=record.source_currency) - fields['target_account'] = self.config.asset_account.format(currency=record.target_currency) + fields["source_account"] = self.config.asset_account.format( + currency=record.source_currency + ) + fields["target_account"] = self.config.asset_account.format( + currency=record.target_currency + ) case WiseRecordDirection.IN: - fields['target_account'] = self.config.asset_account.format(currency=record.target_currency) - + fields["target_account"] = self.config.asset_account.format( + currency=record.target_currency + ) + fields |= ruleset.apply(record.match_fields()) args = {} - args['date'] = record.date + args["date"] = record.date # Convert string flag if it exists - if 'flag' in fields: - args['flag'] = TXNFlag(fields['flag']) + if "flag" in fields: + args["flag"] = TXNFlag(fields["flag"]) # Copy string fields over directly - for name in {'payee', 'narration', 'comment', 'document', 'tags', 'links'}: + for name in {"payee", "narration", "comment", "document", "tags", "links"}: if name in fields: args[name] = fields[name] - args['source_posting'] = PartialPosting( - account=fields.get('source_account', None), + args["source_posting"] = PartialPosting( + account=fields.get("source_account", None), amount=-record.from_source_net, - total_cost=record.exchanged_amount if record.source_currency != record.target_currency else None, + total_cost=record.exchanged_amount + if record.source_currency != record.target_currency + else None, ) - args['target_posting'] = PartialPosting( - account=fields.get('target_account', None), + args["target_posting"] = PartialPosting( + account=fields.get("target_account", None), amount=record.to_target, ) if record.source_fee: - args['source_fee_asset_posting'] = PartialPosting( - account=fields.get('source_fee_asset_account', fields.get('source_account', None)), - amount=-record.source_fee + args["source_fee_asset_posting"] = PartialPosting( + account=fields.get( + "source_fee_asset_account", fields.get("source_account", None) + ), + amount=-record.source_fee, ) - args['source_fee_expense_posting'] = PartialPosting( - account=fields.get('source_fee_expense_account', self.config.fee_account), - amount=record.source_fee + args["source_fee_expense_posting"] = PartialPosting( + account=fields.get( + "source_fee_expense_account", self.config.fee_account + ), + amount=record.source_fee, ) if record.target_fee: - args['target_fee_expense_posting'] = PartialPosting( - account=fields.get('target_fee_expense_account', self.config.fee_account), - amount=record.target_fee + args["target_fee_expense_posting"] = PartialPosting( + account=fields.get( + "target_fee_expense_account", self.config.fee_account + ), + amount=record.target_fee, ) txn = PartialTXN(**args) - logger.debug( - f"Converted Wise Record {record!r} to Partial Transaction {txn!r}" - ) + logger.debug(f"Converted Wise Record {record!r} to Partial Transaction {txn!r}") return txn - @classmethod - def ingest_string(cls, data: str) -> list[WiseRecord]: + def ingest_string(self, data: str) -> list[WiseRecord]: """ Parse a string of raw input into a list of records. """ @@ -296,12 +315,10 @@ class WiseConverter(Converter[WiseRecord, WiseConfig]): logging.info(f"Skipping record with non-complete status: {record}") else: records.append(record) - + return records - @classmethod - def ingest_file(cls, path) -> list[WiseRecord]: + def ingest_file(self, path) -> list[WiseRecord]: with open(path) as f: f.readline() - return cls.ingest_string(f.read()) - + return self.ingest_string(f.read()) diff --git a/src/beanify/gui/mainwindow.py b/src/beanify/gui/mainwindow.py index 2cd7f08..6807d7f 100644 --- a/src/beanify/gui/mainwindow.py +++ b/src/beanify/gui/mainwindow.py @@ -16,7 +16,14 @@ from .rowtree import RowTree class MainWindow(ThemedTk): - def __init__(self, beanconfig, converter: Converter, ruleset: RuleSet, initial_files=[], **kwargs): + def __init__( + self, + beanconfig, + converter: Converter, + ruleset: RuleSet, + initial_files=[], + **kwargs, + ): super().__init__(**kwargs) self.beanconfig = beanconfig @@ -40,22 +47,23 @@ class MainWindow(ThemedTk): self.initial_ingest() def load_styles(self): - self.tk.eval(""" - set base_theme_dir ../themes/awthemes-10.4.0/ - package ifneeded awthemes 10.4.0 \ - [list source [file join $base_theme_dir awthemes.tcl]] - package ifneeded colorutils 4.8 \ - [list source [file join $base_theme_dir colorutils.tcl]] - package ifneeded awdark 7.12 \ - [list source [file join $base_theme_dir awdark.tcl]] - """) - self.tk.call("package", "require", "awdark") + # self.tk.eval(""" + # set base_theme_dir ../themes/awthemes-10.4.0/ + # package ifneeded awthemes 10.4.0 \ + # [list source [file join $base_theme_dir awthemes.tcl]] + # package ifneeded colorutils 4.8 \ + # [list source [file join $base_theme_dir colorutils.tcl]] + # package ifneeded awdark 7.12 \ + # [list source [file join $base_theme_dir awdark.tcl]] + # """) + # self.tk.call("package", "require", "awdark") + # style = ttk.Style(self) + # style.theme_use('awdark') style = ttk.Style(self) - style.theme_use('awdark') def setup_menu(self): self.menubar = tk.Menu(self) - self['menu'] = self.menubar + self["menu"] = self.menubar menu_file = tk.Menu(self.menubar, tearoff=0) menu_file.add_command(label="Ingest File", command=self.do_ingest_file) @@ -69,7 +77,9 @@ class MainWindow(ThemedTk): menu_edit = tk.Menu(self.menubar, tearoff=0) menu_edit.add_command(label="Edit Rules", command=self.do_edit_rules) - menu_edit.add_command(label="Edit Preferences", command=self.do_edit_preferences) + menu_edit.add_command( + label="Edit Preferences", command=self.do_edit_preferences + ) self.menubar.add_cascade(menu=menu_edit, label="Edit") @@ -81,32 +91,40 @@ class MainWindow(ThemedTk): # sashthickness=2 # ) - self.contentframe = ttk.Frame(self, padding=(3, 3, 6, 6), border=1, relief="ridge") + self.contentframe = ttk.Frame( + self, padding=(3, 3, 6, 6), border=1, relief="ridge" + ) self.content = ttk.PanedWindow( self.contentframe, - orient='horizontal', + orient="horizontal", # style='Custom.TPanedwindow', ) - - self.rowtree = RowTree(self, base_record=self.sample_record, padding=(3, 3, 12, 12)) + + self.rowtree = RowTree( + self, base_record=self.sample_record, padding=(3, 3, 12, 12) + ) self.content.add(self.rowtree, weight=1) - self.editor = RowEditor(self, acmpl_cache=self.account_cache, padding=(3, 3, 12, 12)) + self.editor = RowEditor( + self, acmpl_cache=self.account_cache, padding=(3, 3, 12, 12) + ) self.content.add(self.editor, weight=1) self.statusbar = tk.Frame(self, relief=tk.SUNKEN) self.status_var_left = StringVar() self.status_var_left.set("Loading...") - self.status_label_left = tk.Label(self.statusbar, textvariable=self.status_var_left) + self.status_label_left = tk.Label( + self.statusbar, textvariable=self.status_var_left + ) self.columnconfigure(0, weight=1) self.rowconfigure(0, weight=1) - self.contentframe.grid(column=0, row=0, sticky='NSEW') - self.content.grid(column=0, row=0, sticky='NSEW') + self.contentframe.grid(column=0, row=0, sticky="NSEW") + self.content.grid(column=0, row=0, sticky="NSEW") - self.statusbar.grid(row=1, column=0, sticky='ESW') - self.status_label_left.grid(row=0, column=0, sticky='E') + self.statusbar.grid(row=1, column=0, sticky="ESW") + self.status_label_left.grid(row=0, column=0, sticky="E") self.rowconfigure(0, weight=1) self.columnconfigure(0, weight=1) @@ -120,11 +138,11 @@ class MainWindow(ThemedTk): # Better to have message passing up the chain? # i.e. the frame passes up the select - self.rowtree.tree.bind('<>', self.row_selected) - self.editor.bind('<>', self.row_updated) - self.editor.bind('<>', self.rule_created) + self.rowtree.tree.bind("<>", self.row_selected) + self.editor.bind("<>", self.row_updated) + self.editor.bind("<>", self.rule_created) - def update_status(self, message=''): + def update_status(self, message=""): self.status_var_left.set(message) # TODO Add number of incomplete txns? # Add record count? @@ -164,7 +182,7 @@ class MainWindow(ThemedTk): txn = self.converter.convert(record, self.ruleset) self.rows[record] = txn - # Tell the table to regenerate + # Tell the table to regenerate self.rowtree.update_rows(self.rows) self.rebuild_account_cache() @@ -182,10 +200,10 @@ class MainWindow(ThemedTk): self.rebuild_account_cache() def do_ingest_file(self): - # Prompt for file to ingest + # Prompt for file to ingest files = filedialog.askopenfilenames( defaultextension=".csv", - filetypes=[("CSV Files", ".csv"), ("All Files", "*.*")] + filetypes=[("CSV Files", ".csv"), ("All Files", "*.*")], ) rows = {} for file in files: @@ -203,8 +221,8 @@ class MainWindow(ThemedTk): # TODO: Feedback and confirmation def do_export_txn(self): - # TODO: Export options - # TODO: Replace fields with defaults + # TODO: Export options + # TODO: Replace fields with defaults upgraded = [] for partialtxn in self.rows.values(): if partialtxn.partial: @@ -220,13 +238,13 @@ class MainWindow(ThemedTk): filetypes=[ ("Beancount Ledger", ".ledger"), ("All Files", "*.*"), - ] + ], ) if path: - with open(path, 'w') as f: + with open(path, "w") as f: for txn in upgraded: f.write(str(txn)) - f.write('\n\n') + f.write("\n\n") message = f"Exported {len(upgraded)} transactions to {path}" else: message = "Export cancelled, no transactions exported" @@ -256,7 +274,9 @@ class MainWindow(ThemedTk): as Record -> PartialTXN associations. """ records = self.converter.ingest_file(path) - rows = {record: self.converter.convert(record, self.ruleset) for record in records} + rows = { + record: self.converter.convert(record, self.ruleset) for record in records + } return rows def rebuild_account_cache(self): @@ -265,9 +285,9 @@ class MainWindow(ThemedTk): i.e. the map of tx.field -> list[options] used for acmpl on entry. """ - # Get all the account field names + # Get all the account field names # Grab the value of each of these for all the rows we have - # Grab the value of each of these for all the rules in the ruleset + # Grab the value of each of these for all the rules in the ruleset # Merge into a map, and update the cached map with it. # Build the list of account names we want to acmpl @@ -290,8 +310,4 @@ class MainWindow(ThemedTk): self.account_cache.clear() self.account_cache |= cache - - def do_reapply_rules(self): - ... - - + def do_reapply_rules(self): ...