import tkinter as tk from datetime import datetime, date from functools import partial from tkinter import BooleanVar, ttk from base.partial import PartialTXN from base.record import Record from base.transaction import Amount from . import logger def datetime_sort_key(datestr: datetime | str): value = datetime.fromtimestamp(0) if isinstance(datestr, (datetime, date)): value = datestr elif isinstance(datestr, str): if datestr: value = datetime.strptime(datestr, "%Y-%m-%d %H:%M:%S") return value def date_sort_key(datestr: date | str): value = datetime.fromtimestamp(0).date() if isinstance(datestr, date): value = datestr elif isinstance(datestr, str): if datestr: value = datetime.strptime(datestr, "%Y-%m-%d").date() return value def amount_sort_key(amount: Amount | str): if isinstance(amount, Amount): value = (amount.currency, amount.value) elif amount: amountstr, currency = amount.strip().split() value = (currency, float(amountstr)) else: value = ('', 0) return value class SortingTree(ttk.Treeview): """ Treeview with column sorting and column selection """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.bind('', self.handle_rightclick) self.showvars = {col: BooleanVar() for col in self['columns']} for var in self.showvars.values(): # Default columns to on var.set(True) for col in self['columns']: self.column(col, stretch=True) # self.set_columns(self['columns']) @property def heading_columns(self): heading_map = {} for col in self["columns"]: heading_map[self.heading(col)['text']] = col return heading_map def reset_column_sizes(self): displayed = self['displaycolumns'] max_widths = [0] * len(displayed) fontname = ttk.Style(self).lookup('TLabel', "font") # treefont = font.nametofont(fontname) # treefont = font.Font(family="TkDefaultFont") magic = 7 for child in self.get_children(''): for (i, col) in enumerate(displayed): data = self.set(child, col) # length = treefont.measure(str(data)) length = len(str(data)) * magic max_widths[i] = max(max_widths[i], length) for i, col in enumerate(displayed): width = max(max_widths[i], 50) + 50 self.column(col, width=width) def heading(self, column, option=None, *, sort_key=None, **kwargs): if sort_key and not hasattr(kwargs, 'command'): command = partial(self._sort, column, False, sort_key) kwargs['command'] = command return super().heading(column, option, **kwargs) def set_columns(self, columns: list[str]): self['displaycolumns'] = columns for col, var in self.showvars.items(): var.set(col in columns) def _sort(self, column, reverse, key): l = [(self.set(k, column), k) for k in self.get_children('')] l.sort(key=lambda t: key(t[0]), reverse=reverse) for index, (_, k) in enumerate(l): self.move(k, '', index) self.heading(column, command=partial(self._sort, column, not reverse, key)) def handle_rightclick(self, event): logger.debug(f"Received right click on SortingTree: {event!r}") region = self.identify_region(event.x, event.y) logger.debug(f"REGION: {region}") if region == 'heading': self.do_column_select(event) def do_column_select(self, event): # Popup a right click menu at the event location with a list of selectable columns logger.debug("Creating column selection menu for SortingTree") menu = tk.Menu(self, tearoff=1) for heading, column in self.heading_columns.items(): menu.add_checkbutton(variable=self.showvars[column], label=heading, command=self._show_columns) try: menu.tk_popup(event.x_root, event.y_root) finally: menu.grab_release() def _show_columns(self): columns = [col for col, var in self.showvars.items() if var.get()] self.set_columns(columns) self.reset_column_sizes() class RowTree(ttk.Frame): def __init__(self, master, base_record: Record, rows={}, **kwargs): super().__init__(master, **kwargs) # The base record is used as a template for the column display self.base_record = base_record self.make_tree() self.items: dict[str, tuple[Record, PartialTXN]] = {} self.record_items: dict[Record, str] = {} self.update_rows(rows) self.sort_by = 0 self.layout() def generate_columns(self, record: Record,): record_fields = record.display_fields() columns = { 'record.date': ("Date", date_sort_key), 'record.from_source': ("Amount (from source)", amount_sort_key), } for field in record_fields: name = f"record.{field.name}" value = getattr(record, field.name) if isinstance(value, Amount): sorter = amount_sort_key elif isinstance(value, datetime): sorter = datetime_sort_key else: sorter = str columns[name] = (field.display_name, sorter) columns |= { 'txn.flag': ("Bean Status", str), 'txn.payee': ("Bean Payee", str), 'txn.narration': ("Bean Narration", str), 'txn.comment': ("Bean Comment", str), 'txn.document': ("Bean Document", str), 'txn.tags': ("Bean Tags", str), 'txn.links': ("Bean Links", str), 'txn.source_account': ("Bean Source", str), 'txn.source_fee_asset_account': ("Bean Source Fee Asset Acc", str), 'txn.source_fee_expense_account': ("Bean Source Fee Expense Acc", str), 'txn.target_account': ("Bean Target", str), 'txn.target_fee_expense_account': ("Bean Target Fee Acc", str), } return columns def make_tree(self): self.columns = self.generate_columns(self.base_record) initially_enabled = [ 'record.date', 'record.source_account', 'record.target_account', 'record.from_source', 'txn.source_account', 'txn.target_account', ] self.tree = SortingTree( self, columns=tuple(self.columns.keys()), show='headings', ) for col, (dname, sort_key) in self.columns.items(): self.tree.heading(col, text=dname, sort_key=sort_key) self.tree.set_columns(initially_enabled) def layout(self): self.rowconfigure(0, weight=1) self.columnconfigure(0, weight=1) self.tree.grid(row=0, column=0, sticky='NSEW') self.tree.rowconfigure(0, weight=1) self.tree.columnconfigure(0, weight=1) self.scrollyframe = ttk.Frame(self, relief='groove') self.scrolly = ttk.Scrollbar( self, orient='vertical', command=self.tree.yview, ) # self.scrollyframe.grid(row=0, column=1, sticky='NS') self.scrolly.grid(row=0, column=1, sticky='NS') self.tree.configure(yscrollcommand=self.scrolly.set) self.scrollx = ttk.Scrollbar( self, orient='horizontal', command=self.tree.xview, ) def get_selected_row(self) -> tuple[Record, PartialTXN] | None: item = self.tree.selection()[0] if item: return self.items[item] else: return None def wipe(self): self.tree.delete(*self.items.keys()) self.items.clear() logger.debug("Wiped the row tree.") def update_rows(self, rows: dict[Record, PartialTXN]): # self.wipe() added = 0 updated = 0 for record, txn in rows.items(): if record in self.record_items: itemid = self.record_items[record] self.tree.item( itemid, values=self.row_values(record, txn) ) updated += 1 else: itemid = self.tree.insert( parent='', index='end', values=self.row_values(record, txn) ) self.record_items[record] = itemid added += 1 self.items[itemid] = (record, txn) logger.debug(f"Added {added} and updated {updated} rows in the RowTree.") def update_this_row(self, record, txn): item = self.tree.selection()[0] self.tree.item(item, values=self.row_values(record, txn)) def row_values(self, record, txn): values = [] for col in self.columns: match col.split('.', maxsplit=1): case ["txn", "flag"]: value = txn.flag.value case ["txn", field]: value = getattr(txn, field) case ["record", "date"]: value = record.date.strftime('%Y-%m-%d') case ["record", field]: value = getattr(record, field) # print(f"{col=} {value=} {type(value)=}") case _: raise ValueError(f"Unexpected column {col}") value = value if value is not None else '' values.append(value) return values