-
-
Notifications
You must be signed in to change notification settings - Fork 270
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
* feat: dynamic embed init * feat: apply embeddingbase class to vectordb * feat: add tests and optimize codes * fix: update import path of embedding_models from autorag to autorag.embedding.base * fix: import err * fix: autorag.embedding_models -> autorag.embedding.base.embedding_models * simplify and delete _check_one_item function and use assertion * return LazyInit at load_from_str and use initialized embedding model at vectordb class * add docs for newer embedding model configuration --------- Co-authored-by: Um Changyong <[email protected]> Co-authored-by: Jeffrey (Dongkyu) Kim <[email protected]>
- Loading branch information
1 parent
0f6bba4
commit 1f093d3
Showing
15 changed files
with
235 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import logging | ||
import sys | ||
|
||
from random import random | ||
from typing import List, Union, Dict | ||
|
||
from llama_index.core.embeddings.mock_embed_model import MockEmbedding | ||
from llama_index.embeddings.openai import OpenAIEmbedding | ||
from llama_index.embeddings.openai import OpenAIEmbeddingModelType | ||
from langchain_openai.embeddings import OpenAIEmbeddings | ||
|
||
from autorag import LazyInit | ||
|
||
logger = logging.getLogger("AutoRAG") | ||
|
||
|
||
class MockEmbeddingRandom(MockEmbedding):
    """MockEmbedding variant that yields random vectors instead of constants.

    Handy in tests that need distinct, non-degenerate embeddings without
    calling a real embedding service.
    """

    def _get_vector(self) -> List[float]:
        # One uniform [0, 1) sample per dimension; embed_dim is supplied
        # by the MockEmbedding base class at construction time.
        vector: List[float] = []
        for _ in range(self.embed_dim):
            vector.append(random())
        return vector
|
||
|
||
# Registry of named embedding-model factories. Values are LazyInit wrappers,
# so (presumably) the underlying client object is only constructed when the
# factory is called, not at import time — TODO confirm LazyInit semantics.
embedding_models = {
    # llama index
    "openai": LazyInit(
        OpenAIEmbedding
    ),  # default model is OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002
    "openai_embed_3_large": LazyInit(
        OpenAIEmbedding, model_name=OpenAIEmbeddingModelType.TEXT_EMBED_3_LARGE
    ),
    "openai_embed_3_small": LazyInit(
        OpenAIEmbedding, model_name=OpenAIEmbeddingModelType.TEXT_EMBED_3_SMALL
    ),
    # 768-dim random vectors for tests; needs no API key or network access
    "mock": LazyInit(MockEmbeddingRandom, embed_dim=768),
    # langchain
    "openai_langchain": LazyInit(OpenAIEmbeddings),
}
|
||
# HuggingFace local embeddings are an optional extra. When the package is
# importable, register a few known-good local models; otherwise fall back to
# the API-only registry and tell the user how to enable local models.
try:
    # you can use your own model in this way.
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding

    embedding_models["huggingface_baai_bge_small"] = LazyInit(
        HuggingFaceEmbedding, model_name="BAAI/bge-small-en-v1.5"
    )
    embedding_models["huggingface_cointegrated_rubert_tiny2"] = LazyInit(
        HuggingFaceEmbedding, model_name="cointegrated/rubert-tiny2"
    )
    embedding_models["huggingface_all_mpnet_base_v2"] = LazyInit(
        HuggingFaceEmbedding,
        model_name="sentence-transformers/all-mpnet-base-v2",
        max_length=512,
    )
    embedding_models["huggingface_bge_m3"] = LazyInit(
        HuggingFaceEmbedding, model_name="BAAI/bge-m3"
    )
    embedding_models["huggingface_multilingual_e5_large"] = LazyInit(
        HuggingFaceEmbedding, model_name="intfloat/multilingual-e5-large-instruct"
    )
except ImportError:
    # Message matches the one used by EmbeddingModel.load_from_dict
    # (original concatenation was missing a space: "AutoRAG.To use ...").
    logger.info(
        "You are using API version of AutoRAG. "
        "To use local version, run `pip install 'AutoRAG[gpu]'`."
    )
|
||
|
||
class EmbeddingModel:
    """Resolve an embedding-model configuration into a LazyInit factory.

    Two configuration shapes are accepted:
      * a string key into the module-level ``embedding_models`` registry
      * a single-element list of dicts: ``[{"type": ..., "model_name": ...}]``
    """

    @staticmethod
    def load(config: Union[str, List[Dict]]):
        """Dispatch on config shape to load_from_str or load_from_dict.

        :raises ValueError: if config is neither a str nor a list.
        """
        if isinstance(config, str):
            return EmbeddingModel.load_from_str(config)
        elif isinstance(config, list):
            return EmbeddingModel.load_from_dict(config)
        else:
            raise ValueError("Invalid type of config")

    @staticmethod
    def load_from_str(name: str):
        """Return the pre-registered factory for ``name``.

        :raises ValueError: if ``name`` is not in the registry.
        """
        try:
            return embedding_models[name]
        except KeyError:
            raise ValueError(f"Embedding model '{name}' is not supported")

    @staticmethod
    def load_from_dict(option: List[dict]):
        """Build a LazyInit factory from ``[{"type": ..., "model_name": ...}]``.

        Extra keys beyond ``type`` are forwarded to the embedding class.

        :raises ValueError: on multiple entries, missing keys, or an
            unsupported/unavailable ``type``.
        """

        def _check_keys(target: dict):
            # Validate the entry's shape before using it.
            if "type" not in target or "model_name" not in target:
                raise ValueError("Both 'type' and 'model_name' must be provided")
            if target["type"] not in ["openai", "huggingface", "mock"]:
                raise ValueError(
                    f"Embedding model type '{target['type']}' is not supported"
                )

        def _get_huggingface_class():
            # The huggingface extra may not be installed; resolve the class
            # from sys.modules (populated by the module-level try-import
            # above) instead of importing unconditionally.
            module = sys.modules.get("llama_index.embeddings.huggingface")
            if not module:
                logger.info(
                    "You are using API version of AutoRAG. "
                    "To use local version, run `pip install 'AutoRAG[gpu]'`."
                )
                return None
            return getattr(module, "HuggingFaceEmbedding", None)

        if len(option) != 1:
            raise ValueError("Only one embedding model is supported")
        _check_keys(option[0])

        # Copy before popping so the caller's config dict is not mutated;
        # otherwise loading the same config twice would fail validation on
        # the second call because 'type' had been popped out of it.
        model_options = dict(option[0])
        model_type = model_options.pop("type")

        # Dispatch lazily: the original eager dict evaluated
        # _get_huggingface_class() (and logged its "API version" hint) even
        # when an openai/mock model was requested.
        if model_type == "huggingface":
            embedding_class = _get_huggingface_class()
        elif model_type == "openai":
            embedding_class = OpenAIEmbedding
        else:  # "mock" — _check_keys already restricted the valid types
            embedding_class = MockEmbeddingRandom

        if not embedding_class:
            raise ValueError(f"Embedding model type '{model_type}' is not supported")

        return LazyInit(embedding_class, **model_options)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import pytest | ||
from llama_index.embeddings.openai import OpenAIEmbedding | ||
|
||
from autorag.embedding.base import EmbeddingModel, MockEmbeddingRandom | ||
|
||
|
||
def test_load_embedding_model():
    """load() accepts both a registry key and a list-of-dict config."""
    by_name = EmbeddingModel.load("mock")
    assert by_name is not None
    assert isinstance(by_name(), MockEmbeddingRandom)

    by_config = EmbeddingModel.load(
        [{"type": "openai", "model_name": "text-embedding-ada-002"}]
    )
    assert by_config is not None
    assert isinstance(by_config(), OpenAIEmbedding)
|
||
|
||
def test_load_from_str_embedding_model():
    """load_from_str() returns registered factories and rejects unknown names."""
    # Known registry key resolves to a callable factory.
    factory = EmbeddingModel.load_from_str("mock")
    assert factory is not None
    assert isinstance(factory(), MockEmbeddingRandom)

    # Unknown key raises with the offending name in the message.
    with pytest.raises(
        ValueError, match="Embedding model 'unsupported_model' is not supported"
    ):
        EmbeddingModel.load_from_str("unsupported_model")
|
||
|
||
def test_load_embedding_model_from_dict():
    """load_from_dict() rejects malformed, unknown, and multi-entry configs."""
    bad_configs = [
        # Missing 'model_name' key
        ([{"type": "openai"}], "Both 'type' and 'model_name' must be provided"),
        # Unsupported embedding type
        (
            [{"type": "unsupported_type", "model_name": "some-model"}],
            "Embedding model type 'unsupported_type' is not supported",
        ),
        # More than one entry is not allowed
        (
            [
                {"type": "openai", "model_name": "text-embedding-ada-002"},
                {"type": "huggingface", "model_name": "BAAI/bge-small-en-v1.5"},
            ],
            "Only one embedding model is supported",
        ),
    ]
    for config, message in bad_configs:
        with pytest.raises(ValueError, match=message):
            EmbeddingModel.load_from_dict(config)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.