Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proof of concept payee prediction #239

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions beancount_import/reconcile.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,9 @@ def _maybe_train_classifier(self):
self.classifier = nltk.classify.scikitlearn.SklearnClassifier(
estimator=sklearn.tree.DecisionTreeClassifier())
self.classifier.train(training_examples)
self.payee_classifier = nltk.classify.scikitlearn.SklearnClassifier(
estimator=sklearn.tree.DecisionTreeClassifier())
self.payee_classifier.train(self.training_examples.payee_examples)
self.reconciler.log_status(
'Trained classifier with %d examples.' % len(training_examples))
classifier_cache_path = self.reconciler.options['classifier_cache']
Expand Down Expand Up @@ -704,6 +707,19 @@ def predict_account(
print('predicted account = %r' % (predicted_account, ))
return predicted_account

def predict_payee(
self, prediction_input: Optional[training.PredictionInput]) -> str:
if self.payee_classifier is None or prediction_input is None:
return None
features = training.get_features(prediction_input)
explanation = get_prediction_explanation(self.payee_classifier, features)
predicted_payee = self.payee_classifier.classify(features)
if display_prediction_explanation:
print('\n'.join(explanation))
print('predicted payee = %r' % (predicted_payee, ))
print('from features', features.keys())
return predicted_payee

def _get_generic_stage(self, entries: Entries):
stage = self.editor.stage_changes()
for entry in entries:
Expand Down Expand Up @@ -739,10 +755,22 @@ def _get_unknown_account_predictions(self,
group_predictions[group_number] for group_number in group_numbers
]

def _get_unknown_payee_prediction(self,
transaction: Transaction) -> List[str]:
group_prediction_inputs = self._feature_extractor.extract_unknown_account_group_features(
transaction)
group_predictions = [
self.predict_payee(prediction_input)
for prediction_input in group_prediction_inputs
]
print("predict:", group_predictions)
return group_predictions[0]

def _make_candidate_with_substitutions(self,
transaction: Transaction,
used_transactions: List[Transaction],
predicted_accounts: List[str],
predicted_payee: Optional[str],
changes: dict = {}):
assert isinstance(changes, dict)
new_accounts = changes.get('accounts')
Expand Down Expand Up @@ -777,9 +805,14 @@ def substitute(changes: dict):
transaction,
used_transactions,
changes=changes,
predicted_accounts=predicted_accounts)
predicted_accounts=predicted_accounts,
predicted_payee=predicted_payee)

new_transaction = _replace_transaction_properties(transaction, changes)

if predicted_payee:
new_transaction = new_transaction._replace(payee=predicted_payee)

real_transaction = _get_transaction_with_substitutions(
new_transaction, new_accounts)
transaction_with_unique_account_names = _get_transaction_with_substitutions(
Expand Down Expand Up @@ -834,11 +867,13 @@ def _make_candidates_from_import_result(self, next_pending):
for transaction, used_transactions in match_results:
predicted_accounts = self._get_unknown_account_predictions(
transaction)
predicted_payee = self._get_unknown_payee_prediction(transaction)
candidates.append(
self._make_candidate_with_substitutions(
transaction,
used_transactions,
predicted_accounts=predicted_accounts))
predicted_accounts=predicted_accounts,
predicted_payee=predicted_payee))
result = Candidates(
candidates=candidates,
date=next_entry.date,
Expand Down
8 changes: 6 additions & 2 deletions beancount_import/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ def get_features(example: PredictionInput) -> Dict[str, bool]:
class TrainingExamples(object):
def __init__(self):
self.training_examples = []
self.payee_examples = []

def add(self, example: PredictionInput, target_account: str):
def add(self, example: PredictionInput, target_account: str, target_payee: str):
self.payee_examples.append((get_features(example), target_payee))
self.training_examples.append((get_features(example), target_account))


Expand Down Expand Up @@ -182,6 +184,7 @@ def extract_examples(self, entries: Entries,
date=entry.date,
key_value_pairs=key_value_pairs),
target_account=posting.account,
target_payee=entry.payee,
)
if got_example: continue

Expand All @@ -208,7 +211,8 @@ def extract_examples(self, entries: Entries,
key_value_pairs=key_value_pairs,
date=get_posting_date(entry, posting),
amount=posting.units),
target_account=target_account)
target_account=target_account,
target_payee=entry.payee)

def extract_unknown_account_group_features(
self, transaction: Transaction) -> List[Optional[PredictionInput]]:
Expand Down