-
-
Notifications
You must be signed in to change notification settings - Fork 272
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add threshold_cutoff * fix docs * add percentile_cutoff
- Loading branch information
Showing
7 changed files
with
95 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from typing import List, Tuple | ||
|
||
import pandas as pd | ||
|
||
from autorag.nodes.passagefilter.base import passage_filter_node | ||
from autorag.utils.util import sort_by_scores, select_top_k | ||
|
||
|
||
@passage_filter_node | ||
def percentile_cutoff(queries: List[str], contents_list: List[List[str]], | ||
scores_list: List[List[float]], ids_list: List[List[str]], | ||
percentile: float, reverse: bool = False, | ||
) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]: | ||
""" | ||
Filter out the contents that are below the content's length times percentile. | ||
If This is a filter and does not override scores. | ||
If the value of content's length times percentile is less than 1, keep the only one highest similarity content. | ||
:param queries: The list of queries to use for filtering | ||
:param contents_list: The list of lists of contents to filter | ||
:param scores_list: The list of lists of scores retrieved | ||
:param ids_list: The list of lists of ids retrieved | ||
:param percentile: The percentile to cut off | ||
:param reverse: If True, the lower the score, the better | ||
Default is False. | ||
:return: Tuple of lists containing the filtered contents, ids, and scores | ||
""" | ||
num_top_k = max(1, int(len(scores_list[0]) * percentile)) | ||
|
||
df = pd.DataFrame({ | ||
'contents': contents_list, | ||
'ids': ids_list, | ||
'scores': scores_list, | ||
}) | ||
|
||
reverse = not reverse | ||
df[['contents', 'ids', 'scores']] = df.apply(sort_by_scores, axis=1, result_type='expand', reverse=reverse) | ||
results = select_top_k(df, ['contents', 'ids', 'scores'], num_top_k) | ||
|
||
return results['contents'].tolist(), results['ids'].tolist(), results['scores'].tolist() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Percentile Cutoff | ||
|
||
This module is inspired by | ||
our [similarity percentile cutoff](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_filter/similarity_percentile_cutoff.html) | ||
module. | ||
|
||
Filter out the contents that are below the content's length times percentile. | ||
|
||
## **Module Parameters** | ||
|
||
- **percentile** : The percentile value to filter out the contents. | ||
This is essential to run the module, so you have to set this parameter. | ||
- **reverse** : If True, the lower the score, the better. | ||
Default is False. | ||
|
||
## **Example config.yaml** | ||
|
||
```yaml | ||
modules: | ||
- module_type: percentile_cutoff | ||
percentile: 0.6 | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
28 changes: 28 additions & 0 deletions
28
tests/autorag/nodes/passagefilter/test_percentile_cutoff.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from autorag.nodes.passagefilter import percentile_cutoff | ||
from tests.autorag.nodes.passagefilter.test_passage_filter_base import queries_example, contents_example, \ | ||
scores_example, ids_example, base_passage_filter_test, project_dir, previous_result, base_passage_filter_node_test | ||
|
||
|
||
def test_percentile_cutoff(): | ||
original_cutoff = percentile_cutoff.__wrapped__ | ||
contents, ids, scores = original_cutoff( | ||
queries_example, contents_example, scores_example, ids_example, percentile=0.6) | ||
base_passage_filter_test(contents, ids, scores) | ||
assert scores[0] == [0.8, 0.5] | ||
assert contents[0] == ["Paris is the capital of France.", | ||
"Paris is one of the capital from France. Isn't it?"] | ||
|
||
|
||
def test_percentile_cutoff_reverse(): | ||
original_cutoff = percentile_cutoff.__wrapped__ | ||
contents, ids, scores = original_cutoff( | ||
queries_example, contents_example, scores_example, ids_example, percentile=0.6, reverse=True) | ||
base_passage_filter_test(contents, ids, scores) | ||
assert scores[0] == [0.1, 0.1] | ||
assert contents[0] == ["NomaDamas is Great Team", "havertz is suck at soccer"] | ||
|
||
|
||
def test_percentile_cutoff_node(): | ||
result_df = percentile_cutoff( | ||
project_dir=project_dir, previous_result=previous_result, percentile=0.9) | ||
base_passage_filter_node_test(result_df) |