Skip to content

Commit

Permalink
remove embedding_model from kwargs for passage filter module (#1043)
Browse files Browse the repository at this point in the history
* remove embedding_model from kwargs for passage filter module

* change pop to get

* Use pop params to reduce error at similarity percentile cutoff

* use pop_params at similarity threshold cutoff

---------

Co-authored-by: Bwook (Byoungwook) Kim <[email protected]>
Co-authored-by: Jeffrey (Dongkyu) Kim <[email protected]>
  • Loading branch information
3 people authored Dec 12, 2024
1 parent aa0bfbf commit 94f51d7
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 12 deletions.
3 changes: 2 additions & 1 deletion autorag/nodes/generator/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def __del__(self):

if torch.cuda.is_available():
from vllm.distributed.parallel_state import (
destroy_model_parallel, destroy_distributed_environment
destroy_model_parallel,
destroy_distributed_environment,
)

destroy_model_parallel()
Expand Down
9 changes: 5 additions & 4 deletions autorag/nodes/passagefilter/similarity_percentile_cutoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
embedding_query_content,
)
from autorag.utils import result_to_dataframe
from autorag.utils.util import empty_cuda_cache
from autorag.utils.util import empty_cuda_cache, pop_params


class SimilarityPercentileCutoff(BasePassageFilter):
Expand All @@ -21,7 +21,7 @@ def __init__(self, project_dir: Union[str, Path], *args, **kwargs):
:param project_dir: The project directory to use for initializing the module
:param embedding_model: The embedding model string to use for calculating similarity
Default is "openai" which is OpenAI text-embedding-ada-002 embedding model.
Default is "openai" which is OpenAI text-embedding-ada-002 embedding model.
"""
super().__init__(project_dir, *args, **kwargs)
embedding_model_str = kwargs.pop("embedding_model", "openai")
Expand All @@ -34,9 +34,10 @@ def __del__(self):
empty_cuda_cache()

@result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"])
def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
def pure(self, previous_result: pd.DataFrame, **kwargs):
queries, contents, scores, ids = self.cast_to_run(previous_result)
return self._pure(queries, contents, scores, ids, *args, **kwargs)
kwargs = pop_params(self._pure, kwargs)
return self._pure(queries, contents, scores, ids, **kwargs)

def _pure(
self,
Expand Down
8 changes: 5 additions & 3 deletions autorag/nodes/passagefilter/similarity_threshold_cutoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from autorag.nodes.passagefilter.base import BasePassageFilter
from autorag.utils.util import (
embedding_query_content,
result_to_dataframe,
empty_cuda_cache,
result_to_dataframe,
pop_params,
)


Expand All @@ -20,10 +21,10 @@ def __init__(self, project_dir: str, *args, **kwargs):
:param project_dir: The project directory to use for initializing the module
:param embedding_model: The embedding model string to use for calculating similarity
Default is "openai" which is OpenAI text-embedding-ada-002 embedding model.
Default is "openai" which is OpenAI text-embedding-ada-002 embedding model.
"""
super().__init__(project_dir, *args, **kwargs)
embedding_model_str = kwargs.pop("embedding_model", "openai")
embedding_model_str = kwargs.get("embedding_model", "openai")
self.embedding_model = embedding_models[embedding_model_str]()

def __del__(self):
Expand All @@ -33,6 +34,7 @@ def __del__(self):

@result_to_dataframe(["retrieved_contents", "retrieved_ids", "retrieve_scores"])
def pure(self, previous_result: pd.DataFrame, *args, **kwargs):
kwargs = pop_params(self._pure, kwargs)
queries, contents, scores, ids = self.cast_to_run(previous_result)
return self._pure(queries, contents, scores, ids, *args, **kwargs)

Expand Down
4 changes: 2 additions & 2 deletions autorag/vectordb/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
user: str = "",
password: str = "",
timeout: Optional[float] = None,
params: Dict[str, Any] = {},
params: Dict[str, Any] = {},
):
super().__init__(embedding_model, similarity_metric, embedding_batch)

Expand All @@ -49,7 +49,7 @@ def __init__(
self.timeout = timeout
self.params = params
self.index_type = index_type

# Set Collection
if not utility.has_collection(collection_name, timeout=timeout):
# Get the dimension of the embeddings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def test_similarity_percentile_cutoff(similarity_percentile_cutoff_instance):
)
def test_similarity_percentile_cutoff_node():
result_df = SimilarityPercentileCutoff.run_evaluator(
project_dir=project_dir, previous_result=previous_result, percentile=0.9
project_dir=project_dir,
previous_result=previous_result,
percentile=0.9,
embedding_model="openai_embed_3_large",
)
base_passage_filter_node_test(result_df)
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def test_similarity_threshold_cutoff(similarity_threshold_cutoff_instance):
)
def test_similarity_threshold_cutoff_node():
result_df = SimilarityThresholdCutoff.run_evaluator(
project_dir=project_dir, previous_result=previous_result, threshold=0.9
project_dir=project_dir,
previous_result=previous_result,
threshold=0.9,
embedding_model="openai_embed_3_large",
marker="big-boy",
)
base_passage_filter_node_test(result_df)

0 comments on commit 94f51d7

Please sign in to comment.