Skip to content

Commit

Permalink
model_name_or_path > model (#418)
Browse files (browse the repository at this point in the history)
* instructor - new secret management

* fix coverage

* retry coverage

* model_name_or_path > model

* linting

* too much renaming :-)

* fix
  • Loading branch information
anakin87 authored Feb 15, 2024
1 parent 999e5a2 commit d0ffe8b
Show file tree
Hide file tree
Showing 10 changed files with 33 additions and 37 deletions.
4 changes: 2 additions & 2 deletions integrations/astra/examples/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
p.add_component(instance=DocumentCleaner(), name="cleaner")
p.add_component(instance=DocumentSplitter(split_by="word", split_length=150, split_overlap=30), name="splitter")
p.add_component(
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
name="embedder",
)
p.add_component(instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="writer")
Expand All @@ -63,7 +63,7 @@
# Create a querying pipeline on the indexed data
q = Pipeline()
q.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
name="embedder",
)
q.add_component("retriever", AstraEmbeddingRetriever(document_store))
Expand Down
4 changes: 2 additions & 2 deletions integrations/astra/examples/pipeline_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
]
p = Pipeline()
p.add_component(
instance=SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
instance=SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
name="embedder",
)
p.add_component(instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="writer")
Expand All @@ -74,7 +74,7 @@
# Construct rag pipeline
rag_pipeline = Pipeline()
rag_pipeline.add_component(
instance=SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"),
instance=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
name="embedder",
)
rag_pipeline.add_component(instance=AstraEmbeddingRetriever(document_store=document_store), name="retriever")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@ class _InstructorEmbeddingBackendFactory:
_instances: ClassVar[Dict[str, "_InstructorEmbeddingBackend"]] = {}

@staticmethod
def get_embedding_backend(model_name_or_path: str, device: Optional[str] = None, token: Optional[Secret] = None):
embedding_backend_id = f"{model_name_or_path}{device}{token}"
def get_embedding_backend(model: str, device: Optional[str] = None, token: Optional[Secret] = None):
embedding_backend_id = f"{model}{device}{token}"

if embedding_backend_id in _InstructorEmbeddingBackendFactory._instances:
return _InstructorEmbeddingBackendFactory._instances[embedding_backend_id]

embedding_backend = _InstructorEmbeddingBackend(
model_name_or_path=model_name_or_path, device=device, token=token
)
embedding_backend = _InstructorEmbeddingBackend(model=model, device=device, token=token)
_InstructorEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend

Expand All @@ -33,9 +31,9 @@ class _InstructorEmbeddingBackend:
Class to manage INSTRUCTOR embeddings.
"""

def __init__(self, model_name_or_path: str, device: Optional[str] = None, token: Optional[Secret] = None):
def __init__(self, model: str, device: Optional[str] = None, token: Optional[Secret] = None):
self.model = INSTRUCTOR(
model_name_or_path=model_name_or_path,
model_name_or_path=model,
device=device,
use_auth_token=token.resolve_value() if token else None,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
"""

self.model_name_or_path = model
self.model = model
# TODO: remove device parameter and use Haystack's device management once migrated
self.device = device or "cpu"
self.token = token
Expand All @@ -113,7 +113,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
model=self.model_name_or_path,
model=self.model,
device=self.device,
token=self.token.to_dict() if self.token else None,
instruction=self.instruction,
Expand All @@ -138,7 +138,7 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path=self.model_name_or_path, device=self.device, token=self.token
model=self.model, device=self.device, token=self.token
)

@component.output_types(documents=List[Document])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
:param normalize_embeddings: If set to true, returned vectors will have the length of 1.
"""

self.model_name_or_path = model
self.model = model
# TODO: remove device parameter and use Haystack's device management once migrated
self.device = device or "cpu"
self.token = token
Expand All @@ -81,7 +81,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
model=self.model_name_or_path,
model=self.model,
device=self.device,
token=self.token.to_dict() if self.token else None,
instruction=self.instruction,
Expand All @@ -104,7 +104,7 @@ def warm_up(self):
"""
if not hasattr(self, "embedding_backend"):
self.embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path=self.model_name_or_path, device=self.device, token=self.token
model=self.model, device=self.device, token=self.token
)

@component.output_types(embedding=List[float])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
)
def test_factory_behavior(mock_instructor): # noqa: ARG001
embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path="hkunlp/instructor-large", device="cpu"
model="hkunlp/instructor-large", device="cpu"
)
same_embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend("hkunlp/instructor-large", "cpu")
another_embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path="hkunlp/instructor-base", device="cpu"
model="hkunlp/instructor-base", device="cpu"
)

assert same_embedding_backend is embedding_backend
Expand All @@ -30,7 +30,7 @@ def test_factory_behavior(mock_instructor): # noqa: ARG001
)
def test_model_initialization(mock_instructor):
_InstructorEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path="hkunlp/instructor-base", device="cpu", token=Secret.from_token("fake-api-token")
model="hkunlp/instructor-base", device="cpu", token=Secret.from_token("fake-api-token")
)
mock_instructor.assert_called_once_with(
model_name_or_path="hkunlp/instructor-base", device="cpu", use_auth_token="fake-api-token"
Expand All @@ -43,9 +43,7 @@ def test_model_initialization(mock_instructor):
"haystack_integrations.components.embedders.instructor_embedders.embedding_backend.instructor_backend.INSTRUCTOR"
)
def test_embedding_function_with_kwargs(mock_instructor): # noqa: ARG001
embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend(
model_name_or_path="hkunlp/instructor-base"
)
embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend(model="hkunlp/instructor-base")

data = [["instruction", "sentence1"], ["instruction", "sentence2"]]
embedding_backend.embed(data=data, normalize_embeddings=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_init_default(self):
Test default initialization parameters for InstructorDocumentEmbedder.
"""
embedder = InstructorDocumentEmbedder(model="hkunlp/instructor-base")
assert embedder.model_name_or_path == "hkunlp/instructor-base"
assert embedder.model == "hkunlp/instructor-base"
assert embedder.device == "cpu"
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.instruction == "Represent the document"
Expand All @@ -38,7 +38,7 @@ def test_init_with_parameters(self):
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
)
assert embedder.model_name_or_path == "hkunlp/instructor-base"
assert embedder.model == "hkunlp/instructor-base"
assert embedder.device == "cuda"
assert embedder.token == Secret.from_token("fake-api-token")
assert embedder.instruction == "Represent the 'domain' 'text_type' for 'task_objective'"
Expand Down Expand Up @@ -118,7 +118,7 @@ def test_from_dict(self):
},
}
embedder = InstructorDocumentEmbedder.from_dict(embedder_dict)
assert embedder.model_name_or_path == "hkunlp/instructor-base"
assert embedder.model == "hkunlp/instructor-base"
assert embedder.device == "cpu"
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.instruction == "Represent the 'domain' 'text_type' for 'task_objective'"
Expand Down Expand Up @@ -147,7 +147,7 @@ def test_from_dict_with_custom_init_parameters(self):
},
}
embedder = InstructorDocumentEmbedder.from_dict(embedder_dict)
assert embedder.model_name_or_path == "hkunlp/instructor-base"
assert embedder.model == "hkunlp/instructor-base"
assert embedder.device == "cuda"
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.instruction == "Represent the financial document for retrieval"
Expand All @@ -168,7 +168,7 @@ def test_warmup(self, mocked_factory):
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model_name_or_path="hkunlp/instructor-base",
model="hkunlp/instructor-base",
device="cpu",
token=Secret.from_env_var("HF_API_TOKEN", strict=False),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_init_default(self):
Test default initialization parameters for InstructorTextEmbedder.
"""
embedder = InstructorTextEmbedder(model="hkunlp/instructor-base")
assert embedder.model_name_or_path == "hkunlp/instructor-base"
assert embedder.model == "hkunlp/instructor-base"
assert embedder.device == "cpu"
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.instruction == "Represent the sentence"
Expand All @@ -33,7 +33,7 @@ def test_init_with_parameters(self):
progress_bar=False,
normalize_embeddings=True,
)
assert embedder.model_name_or_path == "hkunlp/instructor-base"
assert embedder.model == "hkunlp/instructor-base"
assert embedder.device == "cuda"
assert embedder.token == Secret.from_token("fake-api-token")
assert embedder.instruction == "Represent the 'domain' 'text_type' for 'task_objective'"
Expand Down Expand Up @@ -103,7 +103,7 @@ def test_from_dict(self):
},
}
embedder = InstructorTextEmbedder.from_dict(embedder_dict)
assert embedder.model_name_or_path == "hkunlp/instructor-base"
assert embedder.model == "hkunlp/instructor-base"
assert embedder.device == "cpu"
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.instruction == "Represent the 'domain' 'text_type' for 'task_objective'"
Expand All @@ -128,7 +128,7 @@ def test_from_dict_with_custom_init_parameters(self):
},
}
embedder = InstructorTextEmbedder.from_dict(embedder_dict)
assert embedder.model_name_or_path == "hkunlp/instructor-base"
assert embedder.model == "hkunlp/instructor-base"
assert embedder.device == "cuda"
assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert embedder.instruction == "Represent the financial document for retrieval"
Expand All @@ -147,7 +147,7 @@ def test_warmup(self, mocked_factory):
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model_name_or_path="hkunlp/instructor-base",
model="hkunlp/instructor-base",
device="cpu",
token=Secret.from_env_var("HF_API_TOKEN", strict=False),
)
Expand Down
4 changes: 2 additions & 2 deletions integrations/llama_cpp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ Index the documents to the `InMemoryDocumentStore` using the `SentenceTransforme

```python
doc_store = InMemoryDocumentStore(embedding_similarity_function="cosine")
doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")
doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")

# Indexing Pipeline
indexing_pipeline = Pipeline()
Expand All @@ -188,7 +188,7 @@ GPT4 Correct Assistant:

rag_pipeline = Pipeline()

text_embedder = SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")
text_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")

# Load the LLM using LlamaCppGenerator
model_path = "openchat-3.5-1210.Q3_K_S.gguf"
Expand Down
4 changes: 2 additions & 2 deletions integrations/llama_cpp/examples/rag_pipeline_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
]

doc_store = InMemoryDocumentStore(embedding_similarity_function="cosine")
doc_embedder = SentenceTransformersDocumentEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")
doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")


# Indexing Pipeline
Expand All @@ -47,7 +47,7 @@
"""
rag_pipeline = Pipeline()

text_embedder = SentenceTransformersTextEmbedder(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")
text_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")

model_path = "openchat-3.5-1210.Q3_K_S.gguf"
generator = LlamaCppGenerator(model_path=model_path, n_ctx=4096, n_batch=128)
Expand Down

0 comments on commit d0ffe8b

Please sign in to comment.