Skip to content

Commit

Permalink
Add percentile cutoff module (#429)
Browse files Browse the repository at this point in the history
* Add threshold_cutoff

* fix docs

* add percentile_cutoff
  • Loading branch information
bwook00 authored May 14, 2024
1 parent bcfdfb4 commit 2dee056
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 0 deletions.
1 change: 1 addition & 0 deletions autorag/nodes/passagefilter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .pass_passage_filter import pass_passage_filter
from .percentile_cutoff import percentile_cutoff
from .recency import recency_filter
from .similarity_percentile_cutoff import similarity_percentile_cutoff
from .similarity_threshold_cutoff import similarity_threshold_cutoff
Expand Down
40 changes: 40 additions & 0 deletions autorag/nodes/passagefilter/percentile_cutoff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import List, Tuple

import pandas as pd

from autorag.nodes.passagefilter.base import passage_filter_node
from autorag.utils.util import sort_by_scores, select_top_k


@passage_filter_node
def percentile_cutoff(queries: List[str], contents_list: List[List[str]],
scores_list: List[List[float]], ids_list: List[List[str]],
percentile: float, reverse: bool = False,
) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
"""
Filter out the contents that are below the content's length times percentile.
If This is a filter and does not override scores.
If the value of content's length times percentile is less than 1, keep the only one highest similarity content.
:param queries: The list of queries to use for filtering
:param contents_list: The list of lists of contents to filter
:param scores_list: The list of lists of scores retrieved
:param ids_list: The list of lists of ids retrieved
:param percentile: The percentile to cut off
:param reverse: If True, the lower the score, the better
Default is False.
:return: Tuple of lists containing the filtered contents, ids, and scores
"""
num_top_k = max(1, int(len(scores_list[0]) * percentile))

df = pd.DataFrame({
'contents': contents_list,
'ids': ids_list,
'scores': scores_list,
})

reverse = not reverse
df[['contents', 'ids', 'scores']] = df.apply(sort_by_scores, axis=1, result_type='expand', reverse=reverse)
results = select_top_k(df, ['contents', 'ids', 'scores'], num_top_k)

return results['contents'].tolist(), results['ids'].tolist(), results['scores'].tolist()
1 change: 1 addition & 0 deletions autorag/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def get_support_modules(module_name: str) -> Callable:
'similarity_percentile_cutoff': ('autorag.nodes.passagefilter', 'similarity_percentile_cutoff'),
'recency_filter': ('autorag.nodes.passagefilter', 'recency_filter'),
'threshold_cutoff': ('autorag.nodes.passagefilter', 'threshold_cutoff'),
'percentile_cutoff': ('autorag.nodes.passagefilter', 'percentile_cutoff'),
# passage_compressor
'tree_summarize': ('autorag.nodes.passagecompressor', 'tree_summarize'),
'pass_compressor': ('autorag.nodes.passagecompressor', 'pass_compressor'),
Expand Down
1 change: 1 addition & 0 deletions docs/source/nodes/passage_filter/passage_filter.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,5 @@ similarity_threshold_cutoff.md
similarity_percentile_cutoff.md
recency_filter.md
threshold_cutoff.md
percentile_cutoff.md
```
22 changes: 22 additions & 0 deletions docs/source/nodes/passage_filter/percentile_cutoff.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Percentile Cutoff

This module is inspired by
our [similarity percentile cutoff](https://marker-inc-korea.github.io/AutoRAG/nodes/passage_filter/similarity_percentile_cutoff.html)
module.

Filter out the contents that are below the content's length times percentile.

## **Module Parameters**

- **percentile** : The percentile value to filter out the contents.
This is essential to run the module, so you have to set this parameter.
- **reverse** : If True, the lower the score, the better.
Default is False.

## **Example config.yaml**

```yaml
modules:
- module_type: percentile_cutoff
percentile: 0.6
```
2 changes: 2 additions & 0 deletions sample_config/full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ node_lines:
threshold: 2015-01-01
- module_type: threshold_cutoff
threshold: 0.85
- module_type: percentile_cutoff
percentile: 0.6
- node_type: passage_compressor
strategy:
metrics: [retrieval_token_f1, retrieval_token_recall, retrieval_token_precision]
Expand Down
28 changes: 28 additions & 0 deletions tests/autorag/nodes/passagefilter/test_percentile_cutoff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from autorag.nodes.passagefilter import percentile_cutoff
from tests.autorag.nodes.passagefilter.test_passage_filter_base import queries_example, contents_example, \
scores_example, ids_example, base_passage_filter_test, project_dir, previous_result, base_passage_filter_node_test


def test_percentile_cutoff():
original_cutoff = percentile_cutoff.__wrapped__
contents, ids, scores = original_cutoff(
queries_example, contents_example, scores_example, ids_example, percentile=0.6)
base_passage_filter_test(contents, ids, scores)
assert scores[0] == [0.8, 0.5]
assert contents[0] == ["Paris is the capital of France.",
"Paris is one of the capital from France. Isn't it?"]


def test_percentile_cutoff_reverse():
original_cutoff = percentile_cutoff.__wrapped__
contents, ids, scores = original_cutoff(
queries_example, contents_example, scores_example, ids_example, percentile=0.6, reverse=True)
base_passage_filter_test(contents, ids, scores)
assert scores[0] == [0.1, 0.1]
assert contents[0] == ["NomaDamas is Great Team", "havertz is suck at soccer"]


def test_percentile_cutoff_node():
result_df = percentile_cutoff(
project_dir=project_dir, previous_result=previous_result, percentile=0.9)
base_passage_filter_node_test(result_df)

0 comments on commit 2dee056

Please sign in to comment.