diff --git a/src/ficamp/__main__.py b/src/ficamp/__main__.py index a7f38ff..9487dda 100644 --- a/src/ficamp/__main__.py +++ b/src/ficamp/__main__.py @@ -1,20 +1,13 @@ import argparse -import json -import os -import shutil -from enum import StrEnum +from collections import defaultdict import questionary from dotenv import load_dotenv from sqlmodel import Session, SQLModel, create_engine, select -from ficamp.classifier.infer import infer_tx_category +from ficamp.classifier.keywords import sort_by_keyword_matches from ficamp.classifier.preprocessing import preprocess from ficamp.datastructures import Tx -from ficamp.parsers.abn import AbnParser -from ficamp.parsers.bbva import AccountBBVAParser, CreditCardBBVAParser -from ficamp.parsers.bsabadell import AccountBSabadellParser, CreditCardBSabadellParser -from ficamp.parsers.caixabank import CaixaBankParser from ficamp.parsers.enums import BankParser @@ -37,14 +30,13 @@ def cli() -> argparse.Namespace: default="abn", help="Specify the bank for the import", ) - import_parser.add_argument("filename", help="File to load") + import_parser.add_argument("--filename", help="File to load") import_parser.set_defaults(func=import_data) # Subparser for the categorize command categorize_parser = subparsers.add_parser( "categorize", help="Categorize transactions" ) - categorize_parser.add_argument("--infer-category", action="store_true") categorize_parser.set_defaults(func=categorize) args = parser.parse_args() @@ -80,25 +72,34 @@ class DefaultAnswers: NEW = "Type a new category" -def query_business_category(tx, session, infer_category=False): - # first try to get from the category_dict +def make_map_cat_to_kws(session): + statement = select(Tx).where(Tx.category.is_not(None)) + known_cat_tx = session.exec(statement).all() + keywords = defaultdict(list) + for tx in known_cat_tx: + keywords[tx.category].extend(tx.concept_clean.split()) + return keywords + + +def query_business_category(tx, session): + # Clean up the transaction concept string tx.concept_clean = preprocess(tx.concept) + + # If there is an exact match to the known transactions, return that one statement = select(Tx.category).where(Tx.concept_clean == tx.concept_clean) category = session.exec(statement).first() if category: return category - # ask the user if we don't know it - # query each time to update - statement = select(Tx.category).where(Tx.category.is_not(None)).distinct() - categories_choices = session.exec(statement).all() + + # Build map of category --> keywords + cats = make_map_cat_to_kws(session) + cats_sorted_by_matches = sort_by_keyword_matches(cats, tx.concept_clean) + # Show categories to user sorted by keyword criterion + categories_choices = [cat for _, cat in cats_sorted_by_matches] categories_choices.extend([DefaultAnswers.NEW, DefaultAnswers.SKIP]) - default_choice = DefaultAnswers.SKIP - if infer_category: - inferred_category = infer_tx_category(tx) - if inferred_category: - categories_choices.append(inferred_category) - default_choice = inferred_category - print(f"{tx.date.isoformat()} {tx.amount} {tx.concept_clean}") + default_choice = categories_choices[0] + + print(f"{tx.date.isoformat()} | {tx.amount} | {tx.concept_clean}") answer = questionary.select( "Please select the category for this TX", choices=categories_choices, @@ -115,8 +116,8 @@ def query_business_category(tx, session, infer_category=False): return answer -def categorize(args, engine): - """Function to categorize transactions.""" +def categorize(engine): + """Classify transactions into categories""" try: with Session(engine) as session: statement = select(Tx).where(Tx.category.is_(None)) @@ -124,8 +125,7 @@ def categorize(args, engine): print(f"Got {len(results)} Tx to categorize") for tx in results: print(f"Processing {tx}") - tx_category = query_business_category( - tx, session, infer_category=args.infer_category) + tx_category = query_business_category(tx, session) if tx_category: print(f"Saving category for {tx.concept}: {tx_category}") tx.category = tx_category @@ -135,19 +135,16 @@ def categorize(args, engine): else: print("Not saving any category for thi Tx") except KeyboardInterrupt: - print("Closing") + print("Session interrupted. Closing.") def main(): - # create DB - engine = create_engine("sqlite:///ficamp.db") - # create tables - SQLModel.metadata.create_all(engine) - + engine = create_engine("sqlite:///ficamp.db") # create DB + SQLModel.metadata.create_all(engine) # create tables try: args = cli() if args.command: - args.func(args, engine) + args.func(engine) except KeyboardInterrupt: print("\nClosing") diff --git a/src/ficamp/classifier/encoding.py b/src/ficamp/classifier/encoding.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/ficamp/classifier/features.py b/src/ficamp/classifier/features.py deleted file mode 100644 index 0e1e28e..0000000 --- a/src/ficamp/classifier/features.py +++ /dev/null @@ -1,88 +0,0 @@ -from dataclasses import dataclass -from functools import lru_cache -from typing import Any - -import numpy as np -from sklearn.utils import murmurhash3_32 - - -@dataclass -class HashFeature: - cardinality: int - - -def sort_dict(d): - return dict(sorted(d.items(), key=lambda x: x[0])) - - -@lru_cache(500) -def categorical_hash_bucket(item: str, *, buckets: int) -> int: - """Integer representation of item achieved like this: hash(item) % buckets""" - if not isinstance(item, str): - raise ValueError("item must be a string") - return murmurhash3_32(item, positive=True) % buckets - - -# def make_lowercase(d): -# return {"desc": d["desc"].lower()} - - -def has_iban(d) -> str: - """ - Return str(int(True)) if IBAN detected, else str(int(False)). - str because we need string to encode feature""" - # TODO: remove hardcoded value - return d | {"has_iban": str(int(True))} - - -def extract_city(d) -> str: - "Return city name in lowercase if found, else " - # TODO: remove hardcoded value - return d | {"city": "Lleida"} - - -def extract_payment_method(d: dict) -> str | dict[str, Any]: - "Return payment method name in lowercase if found, else " - # TODO:improve logic - - payment_methods = sorted( - ( - "card", - "creditcard", - "paypal", - "transfer", - ) - ) - res = "" - for method in payment_methods: - if method in d["desc"]: - res = method - return d | {"payment_method": res} - - -# This config determines which features will be extracted -# from the transaction description ("desc") -CONFIG = { - "city": HashFeature(cardinality=100), - "has_iban": HashFeature(cardinality=2), - "payment_method": HashFeature(cardinality=10), -} - - -def get_features(preprocessed: dict): - features = {} - for name, conf in CONFIG.items(): - features[name] = categorical_hash_bucket( - preprocessed[name], buckets=conf.cardinality - ) - return features - - -def one_hot_encode(features: dict): - encoded = {} - for name, idx in sort_dict(features).items(): - c = CONFIG[name].cardinality - one_hot = np.zeros(c) - one_hot[idx] = 1 - encoded[name] = one_hot - return encoded diff --git a/src/ficamp/classifier/features_config.py b/src/ficamp/classifier/features_config.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/ficamp/classifier/keywords.py b/src/ficamp/classifier/keywords.py new file mode 100644 index 0000000..30a8049 --- /dev/null +++ b/src/ficamp/classifier/keywords.py @@ -0,0 +1,14 @@ +""" +Logic to sort transactions based on keywords. +""" +import json +import pathlib + + +def sort_by_keyword_matches(categories: dict, description: str) -> list[str]: + description = description.lower() + matches = [] + for category, keywords in categories.items(): + n_matches = sum(keyword in description for keyword in keywords) + matches.append((n_matches, category)) + return sorted(matches, reverse=True) diff --git a/src/ficamp/classifier/payment_method.py b/src/ficamp/classifier/payment_method.py deleted file mode 100644 index 8b13789..0000000 --- a/src/ficamp/classifier/payment_method.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/ficamp/classifier/preprocessing.py b/src/ficamp/classifier/preprocessing.py index 8f13bb5..811ccf5 100644 --- a/src/ficamp/classifier/preprocessing.py +++ b/src/ficamp/classifier/preprocessing.py @@ -1,3 +1,6 @@ +import string + + def remove_digits(s: str) -> str: """ Return string without words that have more that 2 digits. @@ -29,6 +32,27 @@ def remove_comma(s: str) -> str: return " ".join(s.split(",")) +def remove_punctuation(s: str) -> str: + punctuation = set(string.punctuation) + out = "".join((" " if char in punctuation else char for char in s)) + return " ".join(out.split()) # Remove double spaces + + +def remove_isolated_digits(s: str) -> str: + """Remove words made only of digits""" + digits = set(string.digits) + clean = [] + for word in s.split(): + if not all((char in digits for char in word)): + clean.append(word) + return " ".join(clean) + + +def remove_short_words(s: str) -> str: + """Remove words made only of digits""" + return " ".join((word for word in s.split() if len(word) >= 2)) + + def preprocess(s: str) -> str: "Clean up transaction description" steps = ( @@ -37,6 +61,9 @@ def preprocess(s: str) -> str: remove_colon, remove_comma, remove_digits, + remove_punctuation, + remove_isolated_digits, + remove_short_words, ) out = s for func in steps: diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 17e30de..956638d 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -6,6 +6,9 @@ remove_comma, remove_digits, remove_pipes, + remove_punctuation, + remove_isolated_digits, + remove_short_words, ) @@ -56,6 +59,41 @@ def test_remove_comma(inp, exp): assert remove_comma(inp) == exp +@pytest.mark.parametrize( + ("inp,exp"), + ( + ("hello world", "hello world"), + ("hello/world", "hello world"), + ("hello.world", "hello world"), + ("hello.(.world))", "hello world"), + ), +) +def test_remove_punctuation(inp, exp): + assert remove_punctuation(inp) == exp + + +@pytest.mark.parametrize( + ("inp,exp"), + ( + ("hello22 world", "hello22 world"), + ("hello 22 world", "hello world"), + ), +) +def test_remove_isolated_digits(inp, exp): + assert remove_isolated_digits(inp) == exp + + +@pytest.mark.parametrize( + ("inp,exp"), + ( + ("hello a world", "hello world"), + ("hello aa world", "hello aa world"), + ), +) +def test_remove_short_words(inp, exp): + assert remove_short_words(inp) == exp + + @pytest.mark.parametrize( ("inp,exp"), ( @@ -70,6 +108,9 @@ def test_remove_comma(inp, exp): ("SEPA 1231|AMSTERDAM 123BIC", "sepa amsterdam"), ("CSID:NL0213324324324", "csid"), ("CSID:NL0213324324324 HELLO,world1332", "csid hello"), + ("CSID:NL021332432 N26 HELLO,world1332", "csid n26 hello"), + ("CSID:NL021332432 4324 HELLO,world1332", "csid hello"), + ("CSID:NL021332432 n. HELLO,world1332", "csid hello"), ), ) def test_preprocess(inp, exp):