From cb90c0bf91faea0f76f376e2a02e80c1842ad6ca Mon Sep 17 00:00:00 2001 From: Ryan Souza Date: Wed, 18 Sep 2024 17:54:04 -0700 Subject: [PATCH] Proof of concept payee prediction --- beancount_import/reconcile.py | 39 +++++++++++++++++++++++++++++++++-- beancount_import/training.py | 8 +++++-- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/beancount_import/reconcile.py b/beancount_import/reconcile.py index 18d9b20f..0218f771 100644 --- a/beancount_import/reconcile.py +++ b/beancount_import/reconcile.py @@ -475,6 +475,9 @@ def _maybe_train_classifier(self): self.classifier = nltk.classify.scikitlearn.SklearnClassifier( estimator=sklearn.tree.DecisionTreeClassifier()) self.classifier.train(training_examples) + self.payee_classifier = nltk.classify.scikitlearn.SklearnClassifier( + estimator=sklearn.tree.DecisionTreeClassifier()) + self.payee_classifier.train(self.training_examples.payee_examples) self.reconciler.log_status( 'Trained classifier with %d examples.' % len(training_examples)) classifier_cache_path = self.reconciler.options['classifier_cache'] @@ -704,6 +707,19 @@ def predict_account( print('predicted account = %r' % (predicted_account, )) return predicted_account + def predict_payee( + self, prediction_input: Optional[training.PredictionInput]) -> str: + if self.payee_classifier is None or prediction_input is None: + return None + features = training.get_features(prediction_input) + explanation = get_prediction_explanation(self.payee_classifier, features) + predicted_payee = self.payee_classifier.classify(features) + if display_prediction_explanation: + print('\n'.join(explanation)) + print('predicted payee = %r' % (predicted_payee, )) + print('from features', features.keys()) + return predicted_payee + def _get_generic_stage(self, entries: Entries): stage = self.editor.stage_changes() for entry in entries: @@ -739,10 +755,22 @@ def _get_unknown_account_predictions(self, group_predictions[group_number] for group_number in group_numbers ] + def _get_unknown_payee_prediction(self, + transaction: Transaction) -> List[str]: + group_prediction_inputs = self._feature_extractor.extract_unknown_account_group_features( + transaction) + group_predictions = [ + self.predict_payee(prediction_input) + for prediction_input in group_prediction_inputs + ] + print("predict:", group_predictions) + return group_predictions[0] + def _make_candidate_with_substitutions(self, transaction: Transaction, used_transactions: List[Transaction], predicted_accounts: List[str], + predicted_payee: Optional[str], changes: dict = {}): assert isinstance(changes, dict) new_accounts = changes.get('accounts') @@ -777,9 +805,14 @@ def substitute(changes: dict): transaction, used_transactions, changes=changes, - predicted_accounts=predicted_accounts) + predicted_accounts=predicted_accounts, + predicted_payee=predicted_payee) new_transaction = _replace_transaction_properties(transaction, changes) + + if predicted_payee: + new_transaction = new_transaction._replace(payee=predicted_payee) + real_transaction = _get_transaction_with_substitutions( new_transaction, new_accounts) transaction_with_unique_account_names = _get_transaction_with_substitutions( @@ -834,11 +867,13 @@ def _make_candidates_from_import_result(self, next_pending): for transaction, used_transactions in match_results: predicted_accounts = self._get_unknown_account_predictions( transaction) + predicted_payee = self._get_unknown_payee_prediction(transaction) candidates.append( self._make_candidate_with_substitutions( transaction, used_transactions, - predicted_accounts=predicted_accounts)) + predicted_accounts=predicted_accounts, + predicted_payee=predicted_payee)) result = Candidates( candidates=candidates, date=next_entry.date, diff --git a/beancount_import/training.py b/beancount_import/training.py index 30f74173..ac078db4 100644 --- a/beancount_import/training.py +++ b/beancount_import/training.py @@ -52,8 +52,10 @@ def get_features(example: PredictionInput) -> Dict[str, bool]: class TrainingExamples(object): def __init__(self): self.training_examples = [] + self.payee_examples = [] - def add(self, example: PredictionInput, target_account: str): + def add(self, example: PredictionInput, target_account: str, target_payee: str): + self.payee_examples.append((get_features(example), target_payee)) self.training_examples.append((get_features(example), target_account)) @@ -182,6 +184,7 @@ def extract_examples(self, entries: Entries, date=entry.date, key_value_pairs=key_value_pairs), target_account=posting.account, + target_payee=entry.payee, ) if got_example: continue @@ -208,7 +211,8 @@ def extract_examples(self, entries: Entries, key_value_pairs=key_value_pairs, date=get_posting_date(entry, posting), amount=posting.units), - target_account=target_account) + target_account=target_account, + target_payee=entry.payee) def extract_unknown_account_group_features( self, transaction: Transaction) -> List[Optional[PredictionInput]]: