From 2dee0560bc0f0583e1a2db3c0cda633fb834b48d Mon Sep 17 00:00:00 2001 From: "Bwook (Byoungwook) Kim" Date: Tue, 14 May 2024 14:39:13 +0900 Subject: [PATCH] Add percentile cutoff module (#429) * Add threshold_cutoff * fix docs * add percentile_cutoff --- autorag/nodes/passagefilter/__init__.py | 1 + .../nodes/passagefilter/percentile_cutoff.py | 40 +++++++++++++++++++ autorag/support.py | 1 + .../nodes/passage_filter/passage_filter.md | 1 + .../nodes/passage_filter/percentile_cutoff.md | 22 ++++++++++ sample_config/full.yaml | 2 + .../passagefilter/test_percentile_cutoff.py | 28 +++++++++++++ 7 files changed, 95 insertions(+) create mode 100644 autorag/nodes/passagefilter/percentile_cutoff.py create mode 100644 docs/source/nodes/passage_filter/percentile_cutoff.md create mode 100644 tests/autorag/nodes/passagefilter/test_percentile_cutoff.py diff --git a/autorag/nodes/passagefilter/__init__.py b/autorag/nodes/passagefilter/__init__.py index dce8ff9f3..af9993954 100644 --- a/autorag/nodes/passagefilter/__init__.py +++ b/autorag/nodes/passagefilter/__init__.py @@ -1,4 +1,5 @@ from .pass_passage_filter import pass_passage_filter +from .percentile_cutoff import percentile_cutoff from .recency import recency_filter from .similarity_percentile_cutoff import similarity_percentile_cutoff from .similarity_threshold_cutoff import similarity_threshold_cutoff diff --git a/autorag/nodes/passagefilter/percentile_cutoff.py b/autorag/nodes/passagefilter/percentile_cutoff.py new file mode 100644 index 000000000..e9a469364 --- /dev/null +++ b/autorag/nodes/passagefilter/percentile_cutoff.py @@ -0,0 +1,40 @@ +from typing import List, Tuple + +import pandas as pd + +from autorag.nodes.passagefilter.base import passage_filter_node +from autorag.utils.util import sort_by_scores, select_top_k + + +@passage_filter_node +def percentile_cutoff(queries: List[str], contents_list: List[List[str]], + scores_list: List[List[float]], ids_list: List[List[str]], + percentile: float, reverse: bool = False, + ) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]: + """ + Filter out the contents that are below the content's length times percentile. + If This is a filter and does not override scores. + If the value of content's length times percentile is less than 1, keep the only one highest similarity content. + + :param queries: The list of queries to use for filtering + :param contents_list: The list of lists of contents to filter + :param scores_list: The list of lists of scores retrieved + :param ids_list: The list of lists of ids retrieved + :param percentile: The percentile to cut off + :param reverse: If True, the lower the score, the better + Default is False. + :return: Tuple of lists containing the filtered contents, ids, and scores + """ + num_top_k = max(1, int(len(scores_list[0]) * percentile)) + + df = pd.DataFrame({ + 'contents': contents_list, + 'ids': ids_list, + 'scores': scores_list, + }) + + reverse = not reverse + df[['contents', 'ids', 'scores']] = df.apply(sort_by_scores, axis=1, result_type='expand', reverse=reverse) + results = select_top_k(df, ['contents', 'ids', 'scores'], num_top_k) + + return results['contents'].tolist(), results['ids'].tolist(), results['scores'].tolist() diff --git a/autorag/support.py b/autorag/support.py index e552438eb..4f681dbc7 100644 --- a/autorag/support.py +++ b/autorag/support.py @@ -48,6 +48,7 @@ def get_support_modules(module_name: str) -> Callable: 'similarity_percentile_cutoff': ('autorag.nodes.passagefilter', 'similarity_percentile_cutoff'), 'recency_filter': ('autorag.nodes.passagefilter', 'recency_filter'), 'threshold_cutoff': ('autorag.nodes.passagefilter', 'threshold_cutoff'), + 'percentile_cutoff': ('autorag.nodes.passagefilter', 'percentile_cutoff'), # passage_compressor 'tree_summarize': ('autorag.nodes.passagecompressor', 'tree_summarize'), 'pass_compressor': ('autorag.nodes.passagecompressor', 'pass_compressor'), diff --git a/docs/source/nodes/passage_filter/passage_filter.md b/docs/source/nodes/passage_filter/passage_filter.md index d373ee016..5edf060c4 100644 --- a/docs/source/nodes/passage_filter/passage_filter.md +++ b/docs/source/nodes/passage_filter/passage_filter.md @@ -53,4 +53,5 @@ similarity_threshold_cutoff.md similarity_percentile_cutoff.md recency_filter.md threshold_cutoff.md +percentile_cutoff.md ``` diff --git a/docs/source/nodes/passage_filter/percentile_cutoff.md b/docs/source/nodes/passage_filter/percentile_cutoff.md new file mode 100644 index 000000000..e600bddbe --- /dev/null +++ b/docs/source/nodes/passage_filter/percentile_cutoff.md @@ -0,0 +1,22 @@ +# Percentile Cutoff + +This module is inspired by +our [similarity percentile cutoff](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_filter/similarity_percentile_cutoff.html) +module. + +Filter out the contents that are below the content's length times percentile. + +## **Module Parameters** + +- **percentile** : The percentile value to filter out the contents. + This is essential to run the module, so you have to set this parameter. +- **reverse** : If True, the lower the score, the better. + Default is False. + +## **Example config.yaml** + +```yaml +modules: + - module_type: percentile_cutoff + percentile: 0.6 +``` diff --git a/sample_config/full.yaml b/sample_config/full.yaml index cbb7df3c5..bc49c470d 100644 --- a/sample_config/full.yaml +++ b/sample_config/full.yaml @@ -94,6 +94,8 @@ node_lines: threshold: 2015-01-01 - module_type: threshold_cutoff threshold: 0.85 + - module_type: percentile_cutoff + percentile: 0.6 - node_type: passage_compressor strategy: metrics: [retrieval_token_f1, retrieval_token_recall, retrieval_token_precision] diff --git a/tests/autorag/nodes/passagefilter/test_percentile_cutoff.py b/tests/autorag/nodes/passagefilter/test_percentile_cutoff.py new file mode 100644 index 000000000..2e8cb469c --- /dev/null +++ b/tests/autorag/nodes/passagefilter/test_percentile_cutoff.py @@ -0,0 +1,28 @@ +from autorag.nodes.passagefilter import percentile_cutoff +from tests.autorag.nodes.passagefilter.test_passage_filter_base import queries_example, contents_example, \ + scores_example, ids_example, base_passage_filter_test, project_dir, previous_result, base_passage_filter_node_test + + +def test_percentile_cutoff(): + original_cutoff = percentile_cutoff.__wrapped__ + contents, ids, scores = original_cutoff( + queries_example, contents_example, scores_example, ids_example, percentile=0.6) + base_passage_filter_test(contents, ids, scores) + assert scores[0] == [0.8, 0.5] + assert contents[0] == ["Paris is the capital of France.", + "Paris is one of the capital from France. Isn't it?"] + + +def test_percentile_cutoff_reverse(): + original_cutoff = percentile_cutoff.__wrapped__ + contents, ids, scores = original_cutoff( + queries_example, contents_example, scores_example, ids_example, percentile=0.6, reverse=True) + base_passage_filter_test(contents, ids, scores) + assert scores[0] == [0.1, 0.1] + assert contents[0] == ["NomaDamas is Great Team", "havertz is suck at soccer"] + + +def test_percentile_cutoff_node(): + result_df = percentile_cutoff( + project_dir=project_dir, previous_result=previous_result, percentile=0.9) + base_passage_filter_node_test(result_df)