chore: Update docstrings (#499)
* embedders

* generators

* Apply suggestions from code review

Co-authored-by: Madeesh Kannan <[email protected]>

---------

Co-authored-by: Madeesh Kannan <[email protected]>
masci and shadeMe authored Feb 29, 2024
1 parent a631a47 commit 6d1dd7f
Showing 4 changed files with 159 additions and 88 deletions.
@@ -7,6 +7,23 @@

 @component
 class OllamaDocumentEmbedder:
+    """
+    Computes the embeddings of a list of Documents and stores the obtained vectors in the embedding field of each
+    Document. It uses embedding models compatible with the Ollama Library.
+
+    Usage example:
+    ```python
+    from haystack import Document
+    from haystack_integrations.components.embedders.ollama import OllamaDocumentEmbedder
+
+    doc = Document(content="What do llamas say once you have thanked them? No probllama!")
+
+    document_embedder = OllamaDocumentEmbedder()
+    result = document_embedder.run([doc])
+    print(result['documents'][0].embedding)
+    ```
+    """
+
     def __init__(
         self,
         model: str = "nomic-embed-text",
@@ -20,15 +37,16 @@ def __init__(
         embedding_separator: str = "\n",
     ):
         """
-        :param model: The name of the model to use. The model should be available in the running Ollama instance.
-            Default is "nomic-embed-text". "https://ollama.com/library/nomic-embed-text"
-        :param url: The URL of the chat endpoint of a running Ollama instance.
-            Default is "http://localhost:11434/api/embeddings".
-        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
-            top_p, and others. See the available arguments in
+        :param model:
+            The name of the model to use. The model should be available in the running Ollama instance.
+        :param url:
+            The URL of the chat endpoint of a running Ollama instance.
+        :param generation_kwargs:
+            Optional arguments to pass to the Ollama generation endpoint, such as temperature, top_p, and others.
+            See the available arguments in
             [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
-        :param timeout: The number of seconds before throwing a timeout error from the Ollama API.
-            Default is 120 seconds.
+        :param timeout:
+            The number of seconds before throwing a timeout error from the Ollama API.
         """
         self.timeout = timeout
         self.generation_kwargs = generation_kwargs or {}
@@ -44,15 +62,12 @@ def __init__(
     def _create_json_payload(self, text: str, generation_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
         """
         Returns A dictionary of JSON arguments for a POST request to an Ollama service
-        :param text: Text that is to be converted to an embedding
-        :param generation_kwargs:
-        :return: A dictionary of arguments for a POST request to an Ollama service
         """
         return {"model": self.model, "prompt": text, "options": {**self.generation_kwargs, **(generation_kwargs or {})}}

     def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
         """
-        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
+        Prepares the texts to embed by concatenating the Document text with the metadata fields to embed.
         """
         texts_to_embed = []
         for doc in documents:
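For reference, a sketch of the dictionary `_create_json_payload` builds for a POST to the Ollama embeddings endpoint. The concrete values below are illustrative, not from the commit; the point is that call-time `generation_kwargs` are merged over the init-time ones, with call-time values winning:

```python
# Illustrative payload shape, assuming model="nomic-embed-text",
# init-time generation_kwargs={"temperature": 0.8}, and a call-time
# override of {"temperature": 0.1}.
payload = {
    "model": "nomic-embed-text",
    "prompt": "What do llamas say once you have thanked them? No probllama!",
    # {**self.generation_kwargs, **(generation_kwargs or {})} -> call-time wins
    "options": {"temperature": 0.1},
}
```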
@@ -101,12 +116,17 @@ def _embed_batch(
     @component.output_types(documents=List[Document], meta=Dict[str, Any])
     def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None):
         """
-        Run an Ollama Model on a provided documents.
-        :param documents: Documents to be converted to an embedding.
-        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
+        Runs an Ollama Model to compute embeddings of the provided documents.
+
+        :param documents:
+            Documents to be converted to an embedding.
+        :param generation_kwargs:
+            Optional arguments to pass to the Ollama generation endpoint, such as temperature,
             top_p, etc. See the
             [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
-        :return: Documents with embedding information attached and metadata in a dictionary
+        :returns: A dictionary with the following keys:
+            - `documents`: Documents with embedding information attached
+            - `meta`: The metadata collected during the embedding process
         """
         if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
             msg = (
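A minimal sketch of calling the updated `run()`, assuming an Ollama server is running locally with the default `nomic-embed-text` model pulled; the document content and temperature value are illustrative:

```python
from haystack import Document
from haystack_integrations.components.embedders.ollama import OllamaDocumentEmbedder

embedder = OllamaDocumentEmbedder()  # defaults: nomic-embed-text at localhost:11434

docs = [Document(content="Llamas hum to communicate with each other.")]
result = embedder.run(docs, generation_kwargs={"temperature": 0.1})

print(result["documents"][0].embedding)  # `documents`: Documents with embeddings attached
print(result["meta"])                    # `meta`: metadata collected during embedding
```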
@@ -6,6 +6,20 @@

 @component
 class OllamaTextEmbedder:
+    """
+    Computes the embeddings of a list of Documents and stores the obtained vectors in the embedding field of
+    each Document. It uses embedding models compatible with the Ollama Library.
+
+    Usage example:
+    ```python
+    from haystack_integrations.components.embedders.ollama import OllamaTextEmbedder
+
+    embedder = OllamaTextEmbedder()
+    result = embedder.run(text="What do llamas say once you have thanked them? No probllama!")
+    print(result['embedding'])
+    ```
+    """
+
     def __init__(
         self,
         model: str = "nomic-embed-text",
@@ -14,15 +28,16 @@ def __init__(
         timeout: int = 120,
     ):
         """
-        :param model: The name of the model to use. The model should be available in the running Ollama instance.
-            Default is "nomic-embed-text". "https://ollama.com/library/nomic-embed-text"
-        :param url: The URL of the chat endpoint of a running Ollama instance.
-            Default is "http://localhost:11434/api/embeddings".
-        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
+        :param model:
+            The name of the model to use. The model should be available in the running Ollama instance.
+        :param url:
+            The URL of the chat endpoint of a running Ollama instance.
+        :param generation_kwargs:
+            Optional arguments to pass to the Ollama generation endpoint, such as temperature,
             top_p, and others. See the available arguments in
             [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
-        :param timeout: The number of seconds before throwing a timeout error from the Ollama API.
-            Default is 120 seconds.
+        :param timeout:
+            The number of seconds before throwing a timeout error from the Ollama API.
         """
         self.timeout = timeout
         self.generation_kwargs = generation_kwargs or {}
@@ -32,21 +47,23 @@ def __init__(
     def _create_json_payload(self, text: str, generation_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
         """
         Returns A dictionary of JSON arguments for a POST request to an Ollama service
-        :param text: Text that is to be converted to an embedding
-        :param generation_kwargs:
-        :return: A dictionary of arguments for a POST request to an Ollama service
         """
         return {"model": self.model, "prompt": text, "options": {**self.generation_kwargs, **(generation_kwargs or {})}}

     @component.output_types(embedding=List[float], meta=Dict[str, Any])
     def run(self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None):
         """
-        Run an Ollama Model on a given chat history.
-        :param text: Text to be converted to an embedding.
-        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
+        Runs an Ollama Model to compute embeddings of the provided text.
+
+        :param text:
+            Text to be converted to an embedding.
+        :param generation_kwargs:
+            Optional arguments to pass to the Ollama generation endpoint, such as temperature,
             top_p, etc. See the
             [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
-        :return: A dictionary with the key "embedding" and a list of floats as the value
+        :returns: A dictionary with the following keys:
+            - `embedding`: The computed embeddings
+            - `meta`: The metadata collected during the embedding process
         """

         payload = self._create_json_payload(text, generation_kwargs)
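Likewise for the text embedder, a minimal sketch of the two output keys the rewritten `:returns:` section documents, again assuming a local Ollama instance with the default model pulled:

```python
from haystack_integrations.components.embedders.ollama import OllamaTextEmbedder

embedder = OllamaTextEmbedder()
result = embedder.run(text="Llamas hum to communicate with each other.")

print(len(result["embedding"]))  # `embedding`: the computed embedding vector
print(result["meta"])            # `meta`: metadata collected during embedding
```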
@@ -9,8 +9,26 @@
 @component
 class OllamaChatGenerator:
     """
-    Chat Generator based on Ollama. Ollama is a library for easily running LLMs locally.
-    This component provides an interface to generate text using a LLM running in Ollama.
+    Supports models running on Ollama, such as llama2 and mixtral. Find the full list of supported models
+    [here](https://ollama.ai/library).
+
+    Usage example:
+    ```python
+    from haystack_integrations.components.generators.ollama import OllamaChatGenerator
+    from haystack.dataclasses import ChatMessage
+
+    generator = OllamaChatGenerator(model="zephyr",
+                                    url = "http://localhost:11434/api/chat",
+                                    generation_kwargs={
+                                        "num_predict": 100,
+                                        "temperature": 0.9,
+                                    })
+
+    messages = [ChatMessage.from_system("\nYou are a helpful, respectful and honest assistant"),
+                ChatMessage.from_user("What's Natural Language Processing?")]
+
+    print(generator.run(messages=messages))
+    ```
     """

     def __init__(
@@ -22,16 +40,18 @@ def __init__(
         timeout: int = 120,
     ):
         """
-        :param model: The name of the model to use. The model should be available in the running Ollama instance.
-            Default is "orca-mini".
-        :param url: The URL of the chat endpoint of a running Ollama instance.
-            Default is "http://localhost:11434/api/chat".
-        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
+        :param model:
+            The name of the model to use. The model should be available in the running Ollama instance.
+        :param url:
+            The URL of the chat endpoint of a running Ollama instance.
+        :param generation_kwargs:
+            Optional arguments to pass to the Ollama generation endpoint, such as temperature,
             top_p, and others. See the available arguments in
             [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
-        :param template: The full prompt template (overrides what is defined in the Ollama Modelfile).
-        :param timeout: The number of seconds before throwing a timeout error from the Ollama API.
-            Default is 120 seconds.
+        :param template:
+            The full prompt template (overrides what is defined in the Ollama Modelfile).
+        :param timeout:
+            The number of seconds before throwing a timeout error from the Ollama API.
         """

         self.timeout = timeout
@@ -46,9 +66,6 @@ def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]:
     def _create_json_payload(self, messages: List[ChatMessage], generation_kwargs=None) -> Dict[str, Any]:
         """
         Returns A dictionary of JSON arguments for a POST request to an Ollama service
-        :param messages: A history/list of chat messages
-        :param generation_kwargs:
-        :return: A dictionary of arguments for a POST request to an Ollama service
         """
         generation_kwargs = generation_kwargs or {}
         return {
@@ -62,8 +79,6 @@ def _create_json_payload(self, messages: List[ChatMessage], generation_kwargs=No
     def _build_message_from_ollama_response(self, ollama_response: Response) -> ChatMessage:
         """
         Converts the non-streaming response from the Ollama API to a ChatMessage.
-        :param ollama_response: The completion returned by the Ollama API.
-        :return: The ChatMessage.
         """
         json_content = ollama_response.json()
         message = ChatMessage.from_assistant(content=json_content["message"]["content"])
@@ -77,12 +92,16 @@ def run(
         generation_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
-        Run an Ollama Model on a given chat history.
-        :param messages: A list of ChatMessage instances representing the input messages.
-        :param generation_kwargs: Optional arguments to pass to the Ollama generation endpoint, such as temperature,
+        Runs an Ollama Model on a given chat history.
+
+        :param messages:
+            A list of ChatMessage instances representing the input messages.
+        :param generation_kwargs:
+            Optional arguments to pass to the Ollama generation endpoint, such as temperature,
             top_p, etc. See the
             [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
-        :return: A dictionary of the replies containing their metadata
+        :returns: A dictionary with the following keys:
+            - `replies`: The responses from the model
        """
         generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

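And a sketch of the chat generator's `run()` contract. This assumes a local Ollama server and that the default model is still "orca-mini", as the removed docstring line noted; accessing the reply via `.content` follows the `ChatMessage.from_assistant(content=...)` call shown in the diff:

```python
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.ollama import OllamaChatGenerator

generator = OllamaChatGenerator()  # assumed default model: orca-mini

messages = [ChatMessage.from_user("What's Natural Language Processing?")]
result = generator.run(messages=messages)

print(result["replies"][0].content)  # `replies`: ChatMessage responses from the model
```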