Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

model_name_or_path > model #418

Merged
merged 8 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions integrations/astra/examples/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
p.add_component(instance=DocumentCleaner(), name="cleaner")
p.add_component(instance=DocumentSplitter(split_by="word", split_length=150, split_overlap=30), name="splitter")
p.add_component(
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
name="embedder",
)
p.add_component(instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="writer")
Expand All @@ -63,7 +63,7 @@
# Create a querying pipeline on the indexed data
q = Pipeline()
q.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
name="embedder",
)
q.add_component("retriever", AstraEmbeddingRetriever(document_store))
Expand Down
4 changes: 2 additions & 2 deletions integrations/astra/examples/pipeline_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
]
p = Pipeline()
p.add_component(
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
name="embedder",
)
p.add_component(instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="writer")
Expand All @@ -74,7 +74,7 @@
# Construct rag pipeline
rag_pipeline = Pipeline()
rag_pipeline.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
name="embedder",
)
rag_pipeline.add_component(instance=AstraEmbeddingRetriever(document_store=document_store), name="retriever")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@ class _InstructorEmbeddingBackendFactory:
_instances: ClassVar[Dict[str, "_InstructorEmbeddingBackend"]] = {}

@staticmethod
def get_embedding_backend(model_name_or_path: str, device: Optional[str] = None, token: Optional[Secret] = None):
embedding_backend_id = f"{model_name_or_path}{device}{token}"
def get_embedding_backend(model: str, device: Optional[str] = None, token: Optional[Secret] = None):
embedding_backend_id = f"{model}{device}{token}"

if embedding_backend_id in _InstructorEmbeddingBackendFactory._instances:
return _InstructorEmbeddingBackendFactory._instances[embedding_backend_id]

embedding_backend = _InstructorEmbeddingBackend(
model_name_or_path=model_name_or_path, device=device, token=token
)
embedding_backend = _InstructorEmbeddingBackend(model=model, device=device, token=token)
_InstructorEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend

Expand All @@ -33,9 +31,9 @@ class _InstructorEmbeddingBackend:
Class to manage INSTRUCTOR embeddings.
"""

def __init__(self, model_name_or_path: str, device: Optional[str] = None, token: Optional[Secret] = None):
def __init__(self, model: str, device: Optional[str] = None, token: Optional[Secret] = None):
self.model = INSTRUCTOR(
model_name_or_path=model_name_or_path,
model_name_or_path=model,
device=device,
use_auth_token=token.resolve_value() if token else None,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
"""

self.model_name_or_path = model
self.model = model
# TODO: remove device parameter and use Haystack's device management once migrated
self.device = device or "cpu"
self.token = token
Expand All @@ -113,7 +113,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
model=self.model_name_or_path,
model=self.model,
device=self.device,
token=self.token.to_dict() if self.token else None,
instruction=self.instruction,
Expand All @@ -138,7 +138,7 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path=self.model_name_or_path, device=self.device, token=self.token
model=self.model, device=self.device, token=self.token
)

@component.output_types(documents=List[Document])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
:param normalize_embeddings: If set to True, the returned vectors are normalized to have a length (norm) of 1.
"""

self.model_name_or_path = model
self.model = model
# TODO: remove device parameter and use Haystack's device management once migrated
self.device = device or "cpu"
self.token = token
Expand All @@ -81,7 +81,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
model=self.model_name_or_path,
model=self.model,
device=self.device,
token=self.token.to_dict() if self.token else None,
instruction=self.instruction,
Expand All @@ -104,7 +104,7 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path=self.model_name_or_path, device=self.device, token=self.token
model=self.model, device=self.device, token=self.token
)

@component.output_types(embedding=List[float])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
)
def test_factory_behavior(mock_instructor): # noqa: ARG001
embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path="hkunlp/instructor-large", device="cpu"
model="hkunlp/instructor-large", device="cpu"
)
same_embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend("hkunlp/instructor-large", "cpu")
another_embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path="hkunlp/instructor-base", device="cpu"
model="hkunlp/instructor-base", device="cpu"
)

assert same_embedding_backend is embedding_backend
Expand All @@ -30,7 +30,7 @@ def test_factory_behavior(mock_instructor): # noqa: ARG001
)
def test_model_initialization(mock_instructor):
_InstructorEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path="hkunlp/instructor-base", device="cpu", token=Secret.from_token("fake-api-token")
model="hkunlp/instructor-base", device="cpu", token=Secret.from_token("fake-api-token")
)
mock_instructor.assert_called_once_with(
model_name_or_path="hkunlp/instructor-base", device="cpu", use_auth_token="fake-api-token"
Expand All @@ -43,9 +43,7 @@ def test_model_initialization(mock_instructor):
"haystack_integrations.components.embedders.instructor_embedders.embedding_backend.instructor_backend.INSTRUCTOR"
)
def test_embedding_function_with_kwargs(mock_instructor): # noqa: ARG001
embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path="hkunlp/instructor-base"
)
embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend(model="hkunlp/instructor-base")

data = [["instruction", "sentence1"], ["instruction", "sentence2"]]
embedding_backend.embed(data=data, normalize_embeddings=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_init_default(self):
Test default initialization parameters for InstructorDocumentEmbedder.
"""
embedder = InstructorDocumentEmbedder(model="hkunlp/instructor-base")
assert embedder.model_name_or_path == "hkunlp/instructor-base"
assert embedder.model == "hkunlp/instructor-base"
assert embedder.device == "cpu"
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.instruction == "Represent the document"
Expand All @@ -38,7 +38,7 @@ def test_init_with_parameters(self):
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
)
assert embedder.model_name_or_path == "hkunlp/instructor-base"
assert embedder.model == "hkunlp/instructor-base"
assert embedder.device == "cuda"
assert embedder.token == Secret.from_token("fake-api-token")
assert embedder.instruction == "Represent the 'domain' 'text_type' for 'task_objective'"
Expand Down Expand Up @@ -118,7 +118,7 @@ def test_from_dict(self):
},
}
embedder = InstructorDocumentEmbedder.from_dict(embedder_dict)
assert embedder.model_name_or_path == "hkunlp/instructor-base"
assert embedder.model == "hkunlp/instructor-base"
assert embedder.device == "cpu"
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.instruction == "Represent the 'domain' 'text_type' for 'task_objective'"
Expand Down Expand Up @@ -147,7 +147,7 @@ def test_from_dict_with_custom_init_parameters(self):
},
}
embedder = InstructorDocumentEmbedder.from_dict(embedder_dict)
assert embedder.model_name_or_path == "hkunlp/instructor-base"
assert embedder.model == "hkunlp/instructor-base"
assert embedder.device == "cuda"
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.instruction == "Represent the financial document for retrieval"
Expand All @@ -168,7 +168,7 @@ def test_warmup(self, mocked_factory):
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model_name_or_path="hkunlp/instructor-base",
model="hkunlp/instructor-base",
device="cpu",
token=Secret.from_env_var("HF_API_TOKEN", strict=False),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_init_default(self):
Test default initialization parameters for InstructorTextEmbedder.
"""
embedder = InstructorTextEmbedder(model="hkunlp/instructor-base")
assert embedder.model_name_or_path == "hkunlp/instructor-base"
assert embedder.model == "hkunlp/instructor-base"
assert embedder.device == "cpu"
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.instruction == "Represent the sentence"
Expand All @@ -33,7 +33,7 @@ def test_init_with_parameters(self):
progress_bar=False,
normalize_embeddings=True,
)
assert embedder.model_name_or_path == "hkunlp/instructor-base"
assert embedder.model == "hkunlp/instructor-base"
assert embedder.device == "cuda"
assert embedder.token == Secret.from_token("fake-api-token")
assert embedder.instruction == "Represent the 'domain' 'text_type' for 'task_objective'"
Expand Down Expand Up @@ -103,7 +103,7 @@ def test_from_dict(self):
},
}
embedder = InstructorTextEmbedder.from_dict(embedder_dict)
assert embedder.model_name_or_path == "hkunlp/instructor-base"
assert embedder.model == "hkunlp/instructor-base"
assert embedder.device == "cpu"
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.instruction == "Represent the 'domain' 'text_type' for 'task_objective'"
Expand All @@ -128,7 +128,7 @@ def test_from_dict_with_custom_init_parameters(self):
},
}
embedder = InstructorTextEmbedder.from_dict(embedder_dict)
assert embedder.model_name_or_path == "hkunlp/instructor-base"
assert embedder.model == "hkunlp/instructor-base"
assert embedder.device == "cuda"
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.instruction == "Represent the financial document for retrieval"
Expand All @@ -147,7 +147,7 @@ def test_warmup(self, mocked_factory):
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model_name_or_path="hkunlp/instructor-base",
model="hkunlp/instructor-base",
device="cpu",
token=Secret.from_env_var("HF_API_TOKEN", strict=False),
)
Expand Down
4 changes: 2 additions & 2 deletions integrations/llama_cpp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ Index the documents to the `InMemoryDocumentStore` using the `SentenceTransforme

```python
doc_store = InMemoryDocumentStore(embedding_similarity_function="cosine")
doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")
doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")

# Indexing Pipeline
indexing_pipeline = Pipeline()
Expand All @@ -188,7 +188,7 @@ GPT4 Correct Assistant:

rag_pipeline = Pipeline()

text_embedder = SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")
text_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")

# Load the LLM using LlamaCppGenerator
model_path = "openchat-3.5-1210.Q3_K_S.gguf"
Expand Down
4 changes: 2 additions & 2 deletions integrations/llama_cpp/examples/rag_pipeline_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
]

doc_store = InMemoryDocumentStore(embedding_similarity_function="cosine")
doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")
doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")


# Indexing Pipeline
Expand All @@ -47,7 +47,7 @@
"""
rag_pipeline = Pipeline()

text_embedder = SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")
text_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")

model_path = "openchat-3.5-1210.Q3_K_S.gguf"
generator = LlamaCppGenerator(model_path=model_path, n_ctx=4096, n_batch=128)
Expand Down